Building a car classifier

Let's look at a slightly more interesting visual example than character or digit classification: classifying cars. We'll use the Cars dataset from Jonathan Krause et al. at Stanford -- 16,185 images in 196 classes:

http://ai.stanford.edu/~jkrause/cars/car_dataset.html

Goal: showing my work

The goal is to show all the steps of exploration and tuning on a new dataset, including false starts, etc. I deliberately didn't read lots of papers on this specific dataset or application, to better simulate working on an unfamiliar problem. I also didn't refactor the code to be prettier or remove things that aren't really needed. If you want to see how to use a particular technique, there are more concise guides out there.

The plan going in:

  • Load the dataset and do preliminary exploration. Since we're starting with a pre-labeled dataset, we get to skip all the data cleaning.
  • Do some preprocessing to prepare for building a model.
  • Build a toy model on a small piece of the dataset to get a pipeline working.
  • Try to build an SUV/car/other classifier using just these images.
  • Try transfer learning by fine-tuning SqueezeNet.
  • See how much we can gain via data augmentation.
  • Try the classifier on new images from the internet...
  • If we want to go crazy, we can scrape a bunch more car photos from the internet and try out unsupervised pre-training. May be too much for this example.

Note: if you try to run this notebook, make sure you have lots of RAM (I have 16GB) -- it isn't very careful about memory use.

Set up the notebook

In [3]:
%load_ext autoreload
%autoreload 2
In [4]:
# system
import os
import glob
import itertools as it
import operator
from collections import defaultdict
from StringIO import StringIO

# other libraries
import numpy as np 
import pandas as pd
import scipy.io  # for loading .mat files
import scipy.misc # for imresize
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
import seaborn as sns
import requests
In [5]:
# my code -- various helpful utilities
from display import visualize_keras_model, plot_training_curves
from helpers import combine_histories
Using TensorFlow backend.
In [6]:
%matplotlib inline
sns.set_style("white")
p = sns.color_palette()

# repeatability:
np.random.seed(42)
In [7]:
data_root = os.path.expanduser("~/data/cars")

Data exploration

I've already downloaded and unzipped the dataset. Let's see how it's organized.

In [8]:
os.listdir(data_root)
Out[8]:
['car_ims', 'car_ims.tgz', 'cars_annos.mat', 'resized_car_ims', 'source.txt']

Let's look at the annotations first. It's a matlab file; luckily scipy has a function to load it.

In [9]:
cars_annos = scipy.io.loadmat(os.path.join(data_root, 'cars_annos.mat'))
In [10]:
cars_annos.keys()
Out[10]:
['annotations', '__version__', '__header__', 'class_names', '__globals__']

What classes do we have?

Class names seems promising. Let's take a look.

In [8]:
cars_annos['class_names']
Out[8]:
array([[array([u'AM General Hummer SUV 2000'], 
      dtype='<U26'),
        array([u'Acura RL Sedan 2012'], 
      dtype='<U19'),
        array([u'Acura TL Sedan 2012'], 
      dtype='<U19'),
        array([u'Acura TL Type-S 2008'], 
      dtype='<U20'),
        array([u'Acura TSX Sedan 2012'], 
      dtype='<U20'),
        array([u'Acura Integra Type R 2001'], 
      dtype='<U25'),
        array([u'Acura ZDX Hatchback 2012'], 
      dtype='<U24'),
        array([u'Aston Martin V8 Vantage Convertible 2012'], 
      dtype='<U40'),
        array([u'Aston Martin V8 Vantage Coupe 2012'], 
      dtype='<U34'),
        array([u'Aston Martin Virage Convertible 2012'], 
      dtype='<U36'),
        array([u'Aston Martin Virage Coupe 2012'], 
      dtype='<U30'),
        array([u'Audi RS 4 Convertible 2008'], 
      dtype='<U26'),
        array([u'Audi A5 Coupe 2012'], 
      dtype='<U18'),
        array([u'Audi TTS Coupe 2012'], 
      dtype='<U19'),
        array([u'Audi R8 Coupe 2012'], 
      dtype='<U18'),
        array([u'Audi V8 Sedan 1994'], 
      dtype='<U18'),
        array([u'Audi 100 Sedan 1994'], 
      dtype='<U19'),
        array([u'Audi 100 Wagon 1994'], 
      dtype='<U19'),
        array([u'Audi TT Hatchback 2011'], 
      dtype='<U22'),
        array([u'Audi S6 Sedan 2011'], 
      dtype='<U18'),
        array([u'Audi S5 Convertible 2012'], 
      dtype='<U24'),
        array([u'Audi S5 Coupe 2012'], 
      dtype='<U18'),
        array([u'Audi S4 Sedan 2012'], 
      dtype='<U18'),
        array([u'Audi S4 Sedan 2007'], 
      dtype='<U18'),
        array([u'Audi TT RS Coupe 2012'], 
      dtype='<U21'),
        array([u'BMW ActiveHybrid 5 Sedan 2012'], 
      dtype='<U29'),
        array([u'BMW 1 Series Convertible 2012'], 
      dtype='<U29'),
        array([u'BMW 1 Series Coupe 2012'], 
      dtype='<U23'),
        array([u'BMW 3 Series Sedan 2012'], 
      dtype='<U23'),
        array([u'BMW 3 Series Wagon 2012'], 
      dtype='<U23'),
        array([u'BMW 6 Series Convertible 2007'], 
      dtype='<U29'),
        array([u'BMW X5 SUV 2007'], 
      dtype='<U15'),
        array([u'BMW X6 SUV 2012'], 
      dtype='<U15'),
        array([u'BMW M3 Coupe 2012'], 
      dtype='<U17'),
        array([u'BMW M5 Sedan 2010'], 
      dtype='<U17'),
        array([u'BMW M6 Convertible 2010'], 
      dtype='<U23'),
        array([u'BMW X3 SUV 2012'], 
      dtype='<U15'),
        array([u'BMW Z4 Convertible 2012'], 
      dtype='<U23'),
        array([u'Bentley Continental Supersports Conv. Convertible 2012'], 
      dtype='<U54'),
        array([u'Bentley Arnage Sedan 2009'], 
      dtype='<U25'),
        array([u'Bentley Mulsanne Sedan 2011'], 
      dtype='<U27'),
        array([u'Bentley Continental GT Coupe 2012'], 
      dtype='<U33'),
        array([u'Bentley Continental GT Coupe 2007'], 
      dtype='<U33'),
        array([u'Bentley Continental Flying Spur Sedan 2007'], 
      dtype='<U42'),
        array([u'Bugatti Veyron 16.4 Convertible 2009'], 
      dtype='<U36'),
        array([u'Bugatti Veyron 16.4 Coupe 2009'], 
      dtype='<U30'),
        array([u'Buick Regal GS 2012'], 
      dtype='<U19'),
        array([u'Buick Rainier SUV 2007'], 
      dtype='<U22'),
        array([u'Buick Verano Sedan 2012'], 
      dtype='<U23'),
        array([u'Buick Enclave SUV 2012'], 
      dtype='<U22'),
        array([u'Cadillac CTS-V Sedan 2012'], 
      dtype='<U25'),
        array([u'Cadillac SRX SUV 2012'], 
      dtype='<U21'),
        array([u'Cadillac Escalade EXT Crew Cab 2007'], 
      dtype='<U35'),
        array([u'Chevrolet Silverado 1500 Hybrid Crew Cab 2012'], 
      dtype='<U45'),
        array([u'Chevrolet Corvette Convertible 2012'], 
      dtype='<U35'),
        array([u'Chevrolet Corvette ZR1 2012'], 
      dtype='<U27'),
        array([u'Chevrolet Corvette Ron Fellows Edition Z06 2007'], 
      dtype='<U47'),
        array([u'Chevrolet Traverse SUV 2012'], 
      dtype='<U27'),
        array([u'Chevrolet Camaro Convertible 2012'], 
      dtype='<U33'),
        array([u'Chevrolet HHR SS 2010'], 
      dtype='<U21'),
        array([u'Chevrolet Impala Sedan 2007'], 
      dtype='<U27'),
        array([u'Chevrolet Tahoe Hybrid SUV 2012'], 
      dtype='<U31'),
        array([u'Chevrolet Sonic Sedan 2012'], 
      dtype='<U26'),
        array([u'Chevrolet Express Cargo Van 2007'], 
      dtype='<U32'),
        array([u'Chevrolet Avalanche Crew Cab 2012'], 
      dtype='<U33'),
        array([u'Chevrolet Cobalt SS 2010'], 
      dtype='<U24'),
        array([u'Chevrolet Malibu Hybrid Sedan 2010'], 
      dtype='<U34'),
        array([u'Chevrolet TrailBlazer SS 2009'], 
      dtype='<U29'),
        array([u'Chevrolet Silverado 2500HD Regular Cab 2012'], 
      dtype='<U43'),
        array([u'Chevrolet Silverado 1500 Classic Extended Cab 2007'], 
      dtype='<U50'),
        array([u'Chevrolet Express Van 2007'], 
      dtype='<U26'),
        array([u'Chevrolet Monte Carlo Coupe 2007'], 
      dtype='<U32'),
        array([u'Chevrolet Malibu Sedan 2007'], 
      dtype='<U27'),
        array([u'Chevrolet Silverado 1500 Extended Cab 2012'], 
      dtype='<U42'),
        array([u'Chevrolet Silverado 1500 Regular Cab 2012'], 
      dtype='<U41'),
        array([u'Chrysler Aspen SUV 2009'], 
      dtype='<U23'),
        array([u'Chrysler Sebring Convertible 2010'], 
      dtype='<U33'),
        array([u'Chrysler Town and Country Minivan 2012'], 
      dtype='<U38'),
        array([u'Chrysler 300 SRT-8 2010'], 
      dtype='<U23'),
        array([u'Chrysler Crossfire Convertible 2008'], 
      dtype='<U35'),
        array([u'Chrysler PT Cruiser Convertible 2008'], 
      dtype='<U36'),
        array([u'Daewoo Nubira Wagon 2002'], 
      dtype='<U24'),
        array([u'Dodge Caliber Wagon 2012'], 
      dtype='<U24'),
        array([u'Dodge Caliber Wagon 2007'], 
      dtype='<U24'),
        array([u'Dodge Caravan Minivan 1997'], 
      dtype='<U26'),
        array([u'Dodge Ram Pickup 3500 Crew Cab 2010'], 
      dtype='<U35'),
        array([u'Dodge Ram Pickup 3500 Quad Cab 2009'], 
      dtype='<U35'),
        array([u'Dodge Sprinter Cargo Van 2009'], 
      dtype='<U29'),
        array([u'Dodge Journey SUV 2012'], 
      dtype='<U22'),
        array([u'Dodge Dakota Crew Cab 2010'], 
      dtype='<U26'),
        array([u'Dodge Dakota Club Cab 2007'], 
      dtype='<U26'),
        array([u'Dodge Magnum Wagon 2008'], 
      dtype='<U23'),
        array([u'Dodge Challenger SRT8 2011'], 
      dtype='<U26'),
        array([u'Dodge Durango SUV 2012'], 
      dtype='<U22'),
        array([u'Dodge Durango SUV 2007'], 
      dtype='<U22'),
        array([u'Dodge Charger Sedan 2012'], 
      dtype='<U24'),
        array([u'Dodge Charger SRT-8 2009'], 
      dtype='<U24'),
        array([u'Eagle Talon Hatchback 1998'], 
      dtype='<U26'),
        array([u'FIAT 500 Abarth 2012'], 
      dtype='<U20'),
        array([u'FIAT 500 Convertible 2012'], 
      dtype='<U25'),
        array([u'Ferrari FF Coupe 2012'], 
      dtype='<U21'),
        array([u'Ferrari California Convertible 2012'], 
      dtype='<U35'),
        array([u'Ferrari 458 Italia Convertible 2012'], 
      dtype='<U35'),
        array([u'Ferrari 458 Italia Coupe 2012'], 
      dtype='<U29'),
        array([u'Fisker Karma Sedan 2012'], 
      dtype='<U23'),
        array([u'Ford F-450 Super Duty Crew Cab 2012'], 
      dtype='<U35'),
        array([u'Ford Mustang Convertible 2007'], 
      dtype='<U29'),
        array([u'Ford Freestar Minivan 2007'], 
      dtype='<U26'),
        array([u'Ford Expedition EL SUV 2009'], 
      dtype='<U27'),
        array([u'Ford Edge SUV 2012'], 
      dtype='<U18'),
        array([u'Ford Ranger SuperCab 2011'], 
      dtype='<U25'),
        array([u'Ford GT Coupe 2006'], 
      dtype='<U18'),
        array([u'Ford F-150 Regular Cab 2012'], 
      dtype='<U27'),
        array([u'Ford F-150 Regular Cab 2007'], 
      dtype='<U27'),
        array([u'Ford Focus Sedan 2007'], 
      dtype='<U21'),
        array([u'Ford E-Series Wagon Van 2012'], 
      dtype='<U28'),
        array([u'Ford Fiesta Sedan 2012'], 
      dtype='<U22'),
        array([u'GMC Terrain SUV 2012'], 
      dtype='<U20'),
        array([u'GMC Savana Van 2012'], 
      dtype='<U19'),
        array([u'GMC Yukon Hybrid SUV 2012'], 
      dtype='<U25'),
        array([u'GMC Acadia SUV 2012'], 
      dtype='<U19'),
        array([u'GMC Canyon Extended Cab 2012'], 
      dtype='<U28'),
        array([u'Geo Metro Convertible 1993'], 
      dtype='<U26'),
        array([u'HUMMER H3T Crew Cab 2010'], 
      dtype='<U24'),
        array([u'HUMMER H2 SUT Crew Cab 2009'], 
      dtype='<U27'),
        array([u'Honda Odyssey Minivan 2012'], 
      dtype='<U26'),
        array([u'Honda Odyssey Minivan 2007'], 
      dtype='<U26'),
        array([u'Honda Accord Coupe 2012'], 
      dtype='<U23'),
        array([u'Honda Accord Sedan 2012'], 
      dtype='<U23'),
        array([u'Hyundai Veloster Hatchback 2012'], 
      dtype='<U31'),
        array([u'Hyundai Santa Fe SUV 2012'], 
      dtype='<U25'),
        array([u'Hyundai Tucson SUV 2012'], 
      dtype='<U23'),
        array([u'Hyundai Veracruz SUV 2012'], 
      dtype='<U25'),
        array([u'Hyundai Sonata Hybrid Sedan 2012'], 
      dtype='<U32'),
        array([u'Hyundai Elantra Sedan 2007'], 
      dtype='<U26'),
        array([u'Hyundai Accent Sedan 2012'], 
      dtype='<U25'),
        array([u'Hyundai Genesis Sedan 2012'], 
      dtype='<U26'),
        array([u'Hyundai Sonata Sedan 2012'], 
      dtype='<U25'),
        array([u'Hyundai Elantra Touring Hatchback 2012'], 
      dtype='<U38'),
        array([u'Hyundai Azera Sedan 2012'], 
      dtype='<U24'),
        array([u'Infiniti G Coupe IPL 2012'], 
      dtype='<U25'),
        array([u'Infiniti QX56 SUV 2011'], 
      dtype='<U22'),
        array([u'Isuzu Ascender SUV 2008'], 
      dtype='<U23'),
        array([u'Jaguar XK XKR 2012'], 
      dtype='<U18'),
        array([u'Jeep Patriot SUV 2012'], 
      dtype='<U21'),
        array([u'Jeep Wrangler SUV 2012'], 
      dtype='<U22'),
        array([u'Jeep Liberty SUV 2012'], 
      dtype='<U21'),
        array([u'Jeep Grand Cherokee SUV 2012'], 
      dtype='<U28'),
        array([u'Jeep Compass SUV 2012'], 
      dtype='<U21'),
        array([u'Lamborghini Reventon Coupe 2008'], 
      dtype='<U31'),
        array([u'Lamborghini Aventador Coupe 2012'], 
      dtype='<U32'),
        array([u'Lamborghini Gallardo LP 570-4 Superleggera 2012'], 
      dtype='<U47'),
        array([u'Lamborghini Diablo Coupe 2001'], 
      dtype='<U29'),
        array([u'Land Rover Range Rover SUV 2012'], 
      dtype='<U31'),
        array([u'Land Rover LR2 SUV 2012'], 
      dtype='<U23'),
        array([u'Lincoln Town Car Sedan 2011'], 
      dtype='<U27'),
        array([u'MINI Cooper Roadster Convertible 2012'], 
      dtype='<U37'),
        array([u'Maybach Landaulet Convertible 2012'], 
      dtype='<U34'),
        array([u'Mazda Tribute SUV 2011'], 
      dtype='<U22'),
        array([u'McLaren MP4-12C Coupe 2012'], 
      dtype='<U26'),
        array([u'Mercedes-Benz 300-Class Convertible 1993'], 
      dtype='<U40'),
        array([u'Mercedes-Benz C-Class Sedan 2012'], 
      dtype='<U32'),
        array([u'Mercedes-Benz SL-Class Coupe 2009'], 
      dtype='<U33'),
        array([u'Mercedes-Benz E-Class Sedan 2012'], 
      dtype='<U32'),
        array([u'Mercedes-Benz S-Class Sedan 2012'], 
      dtype='<U32'),
        array([u'Mercedes-Benz Sprinter Van 2012'], 
      dtype='<U31'),
        array([u'Mitsubishi Lancer Sedan 2012'], 
      dtype='<U28'),
        array([u'Nissan Leaf Hatchback 2012'], 
      dtype='<U26'),
        array([u'Nissan NV Passenger Van 2012'], 
      dtype='<U28'),
        array([u'Nissan Juke Hatchback 2012'], 
      dtype='<U26'),
        array([u'Nissan 240SX Coupe 1998'], 
      dtype='<U23'),
        array([u'Plymouth Neon Coupe 1999'], 
      dtype='<U24'),
        array([u'Porsche Panamera Sedan 2012'], 
      dtype='<U27'),
        array([u'Ram C/V Cargo Van Minivan 2012'], 
      dtype='<U30'),
        array([u'Rolls-Royce Phantom Drophead Coupe Convertible 2012'], 
      dtype='<U51'),
        array([u'Rolls-Royce Ghost Sedan 2012'], 
      dtype='<U28'),
        array([u'Rolls-Royce Phantom Sedan 2012'], 
      dtype='<U30'),
        array([u'Scion xD Hatchback 2012'], 
      dtype='<U23'),
        array([u'Spyker C8 Convertible 2009'], 
      dtype='<U26'),
        array([u'Spyker C8 Coupe 2009'], 
      dtype='<U20'),
        array([u'Suzuki Aerio Sedan 2007'], 
      dtype='<U23'),
        array([u'Suzuki Kizashi Sedan 2012'], 
      dtype='<U25'),
        array([u'Suzuki SX4 Hatchback 2012'], 
      dtype='<U25'),
        array([u'Suzuki SX4 Sedan 2012'], 
      dtype='<U21'),
        array([u'Tesla Model S Sedan 2012'], 
      dtype='<U24'),
        array([u'Toyota Sequoia SUV 2012'], 
      dtype='<U23'),
        array([u'Toyota Camry Sedan 2012'], 
      dtype='<U23'),
        array([u'Toyota Corolla Sedan 2012'], 
      dtype='<U25'),
        array([u'Toyota 4Runner SUV 2012'], 
      dtype='<U23'),
        array([u'Volkswagen Golf Hatchback 2012'], 
      dtype='<U30'),
        array([u'Volkswagen Golf Hatchback 1991'], 
      dtype='<U30'),
        array([u'Volkswagen Beetle Hatchback 2012'], 
      dtype='<U32'),
        array([u'Volvo C30 Hatchback 2012'], 
      dtype='<U24'),
        array([u'Volvo 240 Sedan 1993'], 
      dtype='<U20'),
        array([u'Volvo XC90 SUV 2007'], 
      dtype='<U19'),
        array([u'smart fortwo Convertible 2012'], 
      dtype='<U29')]], dtype=object)

Funny structure. Let's turn it into a flat list.

In [11]:
classes = [x[0] for x in cars_annos['class_names'][0]]
classes[:10]
Out[11]:
[u'AM General Hummer SUV 2000',
 u'Acura RL Sedan 2012',
 u'Acura TL Sedan 2012',
 u'Acura TL Type-S 2008',
 u'Acura TSX Sedan 2012',
 u'Acura Integra Type R 2001',
 u'Acura ZDX Hatchback 2012',
 u'Aston Martin V8 Vantage Convertible 2012',
 u'Aston Martin V8 Vantage Coupe 2012',
 u'Aston Martin Virage Convertible 2012']

Looks good. Now let's split into (brand, model, type, year) tuples. This will require a bit of manual munging...

In [12]:
first, years = zip(*[(s[:-5], s[-4:])
                     for s in classes])
In [13]:
# Check years...
sorted(pd.unique(years))
Out[13]:
[u'1991',
 u'1993',
 u'1994',
 u'1997',
 u'1998',
 u'1999',
 u'2000',
 u'2001',
 u'2002',
 u'2006',
 u'2007',
 u'2008',
 u'2009',
 u'2010',
 u'2011',
 u'2012']
In [14]:
first[:5]
Out[14]:
(u'AM General Hummer SUV',
 u'Acura RL Sedan',
 u'Acura TL Sedan',
 u'Acura TL Type-S',
 u'Acura TSX Sedan')
In [15]:
# Ok, let's pull out the car types
models, car_types = zip(*[(s.split()[:-1], s.split()[-1])
                         for s in first])
In [16]:
pd.unique(car_types)
Out[16]:
array([u'SUV', u'Sedan', u'Type-S', u'R', u'Hatchback', u'Convertible',
       u'Coupe', u'Wagon', u'GS', u'Cab', u'ZR1', u'Z06', u'SS', u'Van',
       u'Minivan', u'SRT-8', u'SRT8', u'Abarth', u'SuperCab', u'IPL',
       u'XKR', u'Superleggera'], dtype=object)
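For later reference, here's one way we might collapse these messy type tokens into the coarse SUV/car/other buckets from the plan. The bucket assignments below are my own guesses, not part of the dataset:

```python
# Hypothetical mapping from raw type tokens to coarse buckets -- my own
# guesses; we'd want to sanity-check against example images before trusting it.
COARSE_TYPE = {
    'SUV': 'SUV',
    'Sedan': 'car', 'Coupe': 'car', 'Convertible': 'car',
    'Hatchback': 'car', 'Wagon': 'car',
}

def coarse_type(raw_type):
    """Map a raw type token to 'SUV', 'car', or 'other'."""
    return COARSE_TYPE.get(raw_type, 'other')
```

Multi-word types like 'Crew Cab' would still need handling upstream, since we only split off the last token.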

What do the images look like?

Hmm. Car types are a bit messy -- many are multi-word. We'll want to clean this up later. I don't know cars well enough to classify them without seeing example images, so let's look at those next.

In [17]:
images_dir = os.path.join(data_root, "car_ims")
len(os.listdir(images_dir))
Out[17]:
16185

Ok, so we have our 16K images. Let's look at a few random ones.

In [18]:
image_paths = sorted(os.listdir(images_dir))  # sort for a stable ordering
image_paths[12]
Out[18]:
'000013.jpg'
In [19]:
for i in [12, 35, 3600, 12345]:
    img = Image.open(os.path.join(images_dir, image_paths[i]))
    size = img.size # save, since we're about to change it
    img.thumbnail((128, 128)) # mostly to make notebook file smaller
    fig, ax = plt.subplots(figsize=(3,2))
    ax.imshow(img)
    ax.set_title('shape: '+ str(size))
    ax.grid(False)

Ok, we can read the images. They appear to be of vastly different sizes. Let's take a closer look.

In [20]:
shapes = [Image.open(os.path.join(images_dir, path)).size
          for path in image_paths]

No errors! It's nice to have clean data...

In [21]:
widths = [s[0] for s in shapes]
aspect_ratios = [s[0]/float(s[1]) for s in shapes]
In [78]:
fig, ax = plt.subplots(figsize=(4,2.5))
ax.set_title("Image widths")
ax.hist(widths, bins=30)
sns.despine(fig)
In [79]:
fig, ax = plt.subplots(figsize=(4,2.5))
ax.set_title("Aspect ratios")
ax.hist(aspect_ratios, bins=30);
sns.despine(fig)

Essentially all images are in landscape orientation, and aren't too big -- less than 1500px wide. A significant number are pretty small -- just a few hundred pixels. That's probably ok for us -- we'll want to scale down for performance reasons anyway.

Mapping classes and images

Ok, now let's get the classes for each image -- we need to look at the annotations part of the dict...

In [34]:
len(cars_annos['annotations'])
Out[34]:
1
In [35]:
len(cars_annos['annotations'][0])
Out[35]:
16185
In [36]:
cars_annos['annotations'][0][0]
Out[36]:
([u'car_ims/000001.jpg'], [[112]], [[7]], [[853]], [[717]], [[1]], [[0]])
In [37]:
# what are the fields?
cars_annos['annotations'].dtype
Out[37]:
dtype([('relative_im_path', 'O'), ('bbox_x1', 'O'), ('bbox_y1', 'O'), ('bbox_x2', 'O'), ('bbox_y2', 'O'), ('class', 'O'), ('test', 'O')])
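Each field comes back wrapped in its own little nested array, so `.flatten()[0]` digs out the underlying scalar. The trick in isolation, on a toy stand-in for one annotation record:

```python
import numpy as np

# Toy stand-in for one loadmat annotation record: each field is its own
# small nested array, like ([u'car_ims/000001.jpg'], [[112]], [[0]]).
record = (np.array([u'car_ims/000001.jpg']), np.array([[112]]), np.array([[0]]))
flat = [a.flatten()[0] for a in record]
```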
In [24]:
# get rid of the nested arrays
from collections import namedtuple
Example = namedtuple('Example',
                     ['rel_path', 'x1', 'y1', 'x2','y2','cls','test'])
# silly doubly-nested arrays...
examples = [Example(*[a.flatten()[0] for a in x])
           for x in cars_annos['annotations'][0]]
In [39]:
examples[0]
Out[39]:
Example(rel_path=u'car_ims/000001.jpg', x1=112, y1=7, x2=853, y2=717, cls=1, test=0)

Ok, that worked. Now let's look at a couple of images from each class.

In [26]:
key_fn = operator.attrgetter('cls')
by_class = {} # key -> lst
for cls, group in it.groupby(sorted(examples, key=key_fn), key_fn):
    by_class[cls] = list(group) 
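As an aside, the defaultdict we imported earlier does the same grouping without the pre-sort that itertools.groupby requires. A self-contained sketch with toy data:

```python
from collections import defaultdict, namedtuple

Example = namedtuple('Example', ['rel_path', 'cls'])
toy = [Example('a.jpg', 2), Example('b.jpg', 1), Example('c.jpg', 2)]

# No sorting needed: missing keys get a fresh empty list automatically.
by_cls = defaultdict(list)
for ex in toy:
    by_cls[ex.cls].append(ex)
```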
In [27]:
sorted(by_class.keys())[:5]  # note: classes start at 1, not 0
Out[27]:
[1, 2, 3, 4, 5]
In [42]:
# 196 = 14*14
fig, plots = plt.subplots(14,14, sharex='all', sharey='all',
                         figsize=(28,28))
for i in range(196):
    # read the image
    rel_path = by_class[i+1][0].rel_path
    # Note: rel_paths include 'car_ims/'
    img = Image.open(os.path.join(data_root, rel_path))
    img = img.resize((100,100))
    plots[i // 14, i % 14].axis('off')
    plots[i // 14, i % 14].imshow(img)

Pretty! :)

In [28]:
# Which classes have the most examples?
counts = sorted([(k, classes[k-1], len(by_class[k]))
                 for k in by_class.keys()], 
                reverse=True,
               key=operator.itemgetter(2))
In [80]:
print("5 most frequent:\n")
print("\n".join(map(str, counts[:5])))
print("\n5 least frequent:\n")
print("\n".join(map(str, counts[-5:])))
fig, ax = plt.subplots(figsize=(4,2.5))

ax.plot(range(len(counts)), [x[2] for x in counts])
ax.set_title("Sorted counts by class")
sns.despine(fig)
5 most frequent:

(119, u'GMC Savana Van 2012', 136)
(79, u'Chrysler 300 SRT-8 2010', 97)
(161, u'Mercedes-Benz 300-Class Convertible 1993', 96)
(167, u'Mitsubishi Lancer Sedan 2012', 95)
(56, u'Chevrolet Corvette ZR1 2012', 93)

5 least frequent:

(175, u'Rolls-Royce Phantom Drophead Coupe Convertible 2012', 61)
(64, u'Chevrolet Express Cargo Van 2007', 59)
(158, u'Maybach Landaulet Convertible 2012', 58)
(99, u'FIAT 500 Abarth 2012', 55)
(136, u'Hyundai Accent Sedan 2012', 48)

Preprocessing the images

Ok, now that we have a sense of what the dataset looks like, let's do some preprocessing.

  • We'll want fixed-size (227x227) images. Will need to decide whether to just resize, use the provided bounding boxes, resize + crop, etc. (227x227 is the input size used by some ImageNet-trained models, including SqueezeNet, whose classifiers we'll use later.)
  • May also want to convert to grayscale.
  • For now, let's just resize all to 227x227. It'll screw up the aspect ratios a bit, but is a reasonable place to start. How much space will that take if we store the images uncompressed?
In [45]:
# 3 channels x 4 bytes (float32) = 12 bytes per pixel. Size in MB.
16100 * 227 * 227 * 12 / 1024 / 1024
Out[45]:
9494

9.5G is a bit much. Let's keep the images as jpg for now, and decode on the fly. (Could try both and see what's faster...)

In [31]:
resized_path = os.path.join(data_root,'resized_car_ims')
if not os.path.exists(resized_path):
    os.mkdir(resized_path)    
In [133]:
# Resize all the things. Will take a little while.
for fname in image_paths:
    new_path = os.path.join(resized_path, fname)
    # skip if already exists. 
    # (Blow away whole directory if there's a problem, or to change
    # resize strategy)
    if not os.path.exists(new_path):
        img = Image.open(os.path.join(images_dir, fname))
        img = img.resize((227,227))
        img.save(new_path)
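If the aspect-ratio distortion turns out to matter, one alternative is a letterbox-style resize: shrink preserving aspect ratio, then pad to a square. A sketch (not what we use here):

```python
from PIL import Image

def resize_letterbox(img, size=227, fill=(0, 0, 0)):
    """Resize preserving aspect ratio, centered on a solid-color square."""
    img = img.copy()
    img.thumbnail((size, size))  # shrinks in place, keeps aspect ratio
    canvas = Image.new('RGB', (size, size), fill)
    offset = ((size - img.size[0]) // 2, (size - img.size[1]) // 2)
    canvas.paste(img, offset)
    return canvas
```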
In [32]:
# double check
len(os.listdir(resized_path))
Out[32]:
16185
In [48]:
!du -shc {resized_path} {images_dir}
219M	/Users/shnayder/data/cars/resized_car_ims
1.9G	/Users/shnayder/data/cars/car_ims
2.1G	total

Resizing also gives us a much smaller dataset on disk, which will make for faster processing.

Solve a toy problem

Ok, now we have standard size images and roughly understand what things look like. Let's start with a toy problem to get our code working.

Problem 1: distinguishing Hummers from the Acura sedan.

These happen to be the first two classes, and it seems to be a relatively simple task -- certainly for humans.

In [83]:
# here are our classes
classes[:2]
Out[83]:
[u'AM General Hummer SUV 2000', u'Acura RL Sedan 2012']
In [87]:
# and here are a few images for each
fig, plots = plt.subplots(2,6, sharex='all', sharey='all',
                         figsize=(24,6))
for i in range(2):
    for j in range(6):
        # read the image
        rel_path = by_class[i+1][j].rel_path
        # Note: rel_paths include 'car_ims/'
        img = Image.open(os.path.join(data_root, rel_path))
        img = img.resize((100,100))
        plots[i, j].axis('off')
        plots[i, j].imshow(img)
In [95]:
# How many examples do we have?
fig, ax = plt.subplots()
ax.bar([1,2], [len(by_class[i]) for i in [1,2]], tick_label=["Hummer", "Acura"], align="center")
sns.despine(fig)

Step 1: load and prepare the data

Let's get our training, validation, and test data prepared. We'll just keep it in-memory, at least for now.

In [33]:
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils

# https://github.com/fchollet/keras/issues/4499
from keras.layers.core import K
from keras.callbacks import TensorBoard

# for name scopes to make TensorBoard look prettier 
# (doesn't work well yet as of Keras 1.x -- maybe better in 2.x)
import tensorflow as tf 
In [34]:
def gray_to_rgb(im):
    """
    Noticed (due to array projection error in code below) that there is at least
    one grayscale image in the dataset.
    We'll use this to convert.
    """
    w, h = im.shape
    ret = np.empty((w,h,3), dtype=np.uint8)
    ret[:,:,0] = im
    ret[:,:,1] = im
    ret[:,:,2] = im
    return ret
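An equivalent one-liner with np.stack, for reference (same behavior, just more idiomatic numpy):

```python
import numpy as np

def gray_to_rgb(im):
    """Stack a (w, h) grayscale image into a (w, h, 3) RGB array."""
    return np.stack([im, im, im], axis=-1)
```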
In [35]:
def load_examples(by_class, cls, limit=None):
    """
    Load examples for a class. Ignores test/train distinction -- 
    we'll do our own train/validation/test split later.
    
    Args:
        by_class: our above dict -- class_id -> [Example()]
        cls: which class to load
        limit: if not None, only load this many images.
        
    Returns:
        list of (X,y) tuples, one for each image.
            X: 227x227x3 ndarray of type uint8
            y: class_id (will be equal to cls)
    """
    res = []
    to_load = by_class[cls]
    if limit:
        to_load = to_load[:limit]

    for ex in to_load:
        # load the resized image!
        img_path = os.path.join(data_root, 
                        ex.rel_path.replace('car_ims', 'resized_car_ims'))
        img = mpimg.imread(img_path)
        # handle any grayscale images
        if len(img.shape) == 2:
            img = gray_to_rgb(img)
        res.append((img, cls))
    return res

Split into training, validation, and test sets

Train network on training data, tune parameters using validation set, and finally see how well we did on the test set.

In [36]:
def split_examples(xs, valid_frac, test_frac):
    """
    Randomly splits the xs array into train, valid, test, with specified 
    percentages. Rounds down.
    
    Returns:
        (train, valid, test)
    """
    assert valid_frac + test_frac < 1
    
    n = len(xs)
    valid = int(valid_frac * n)
    test = int(test_frac * n)
    train = n - valid - test
    
    # don't change passed-in list
    shuffled = xs[:]
    np.random.shuffle(shuffled)

    return (shuffled[:train], 
            shuffled[train:train + valid], 
            shuffled[train + valid:])

# quick test
split_examples(range(10), 0.2, 0.4)
Out[36]:
([8, 1, 5, 0], [7, 2], [9, 4, 3, 6])
In [53]:
valid_frac = 0.2
test_frac = 0.2
# load the Hummer and acura images
(train, valid, test) = split_examples(load_examples(by_class, 1),
                                     valid_frac, test_frac)
(train2, valid2, test2) = split_examples(load_examples(by_class, 2),
                                     valid_frac, test_frac)

train.extend(train2)
valid.extend(valid2)
test.extend(test2)

# ...and shuffle to make training work better.
np.random.shuffle(train)
np.random.shuffle(valid)
np.random.shuffle(test)
In [54]:
# We have lists of (X,Y) tuples. Let's unzip into lists of Xs and Ys.
X_train, Y_train = zip(*train)
X_valid, Y_valid = zip(*valid)
X_test, Y_test = zip(*test)

# and turn into np arrays of the right dimension.
def convert_X(xs):
    '''
    Take list of (w,h,3) images.
    Turn into an np array, change type to float32.
    '''
    return np.array(xs).astype('float32')
    
X_train = convert_X(X_train)
X_valid = convert_X(X_valid)
X_test = convert_X(X_test)
In [56]:
# Note that despite lots of Keras examples online that want the data to
# have shape 3,w,h, we actually want w,h,3 when using the TensorFlow
# backend.
X_train.shape
Out[56]:
(95, 227, 227, 3)
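If we ever did need the channels-first layout those examples assume, a transpose converts between the two:

```python
import numpy as np

x = np.zeros((95, 227, 227, 3), dtype=np.float32)  # TensorFlow layout: (n, h, w, channels)
x_cf = np.transpose(x, (0, 3, 1, 2))               # channels-first: (n, channels, h, w)
```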

Convert labels to one-hot.

Two notes here:

  1. One-hot means converting class numbers 0 to K into indicator vectors of length K+1: for K=3, 0 becomes 1,0,0,0; 1 becomes 0,1,0,0; etc.
  2. Because we only have two classes, we could actually use a simple binary classifier, but I'm using one-hot because we'll want to classify into more classes later.
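For intuition, here's roughly what np_utils.to_categorical does, in plain numpy:

```python
import numpy as np

def one_hot(ys, n_classes):
    """Minimal numpy version of one-hot encoding."""
    out = np.zeros((len(ys), n_classes))
    out[np.arange(len(ys)), ys] = 1  # set one indicator per row
    return out
```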

First, we need to make sure the class labels are sequential starting from 0.

In [57]:
all_ys = sorted(set(Y_train).union(set(Y_valid)).union(set(Y_test)))
n_classes = len(all_ys)
mapping = dict(zip(all_ys, range(n_classes)))
mapping
Out[57]:
{1: 0, 2: 1}
In [58]:
def convert_Y(ys, mapping):
    '''
    Convert to np array, make class values sequential from 0, 
    and make one-hot
    '''
    ret = np.array([mapping[y] for y in ys])
    n_classes = len(mapping)
    return np_utils.to_categorical(ret, n_classes)

Y_train = convert_Y(Y_train, mapping)
Y_valid = convert_Y(Y_valid, mapping)
Y_test = convert_Y(Y_test, mapping)
In [59]:
Y_train.shape
Out[59]:
(95, 2)

Toy approach: logistic classifier

Let's build a simple logistic classifier. We'll convert the images to grayscale to reduce the number of parameters by a factor of 3.

In [60]:
X_train[0].mean(axis=2).shape
Out[60]:
(227, 227)
In [ ]:
def normalize_to_gray(xs):
    """Convert vals to [0,1], reshape into flat vector.
    Just averaging the three channels, even though that's not the optimal way to do it.
    """
    ret = (xs / 255.0).mean(axis=3)
    return np.reshape(ret, (-1, 227*227))

X_train_gray = normalize_to_gray(X_train)
X_valid_gray = normalize_to_gray(X_valid)
X_test_gray = normalize_to_gray(X_test)
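If we later want a better grayscale conversion, the standard luma weights are a drop-in replacement for the plain channel average (a sketch; not used below):

```python
import numpy as np

def normalize_to_gray_luma(xs):
    """Convert vals to [0,1] grayscale with Rec. 601 luma weights, flatten."""
    weights = np.array([0.299, 0.587, 0.114])
    ret = np.dot(xs / 255.0, weights)  # weighted sum over the channel axis
    return np.reshape(ret, (-1, 227 * 227))
```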
In [62]:
X_train_gray[0].shape
Out[62]:
(51529,)
In [63]:
# (1-x) because grayscale images are backwards--1 is black, it appears
plt.imshow((1 - X_train_gray[0]).reshape((227,227)))
Out[63]:
<matplotlib.image.AxesImage at 0x133d42610>

First "network" -- one layer

In [64]:
def logistic_model():
    model = Sequential()
    # should really use lower resolution if we're making this silly
    # of a model
    model.add(Dense(output_dim=2, input_dim=227*227))
    model.add(Activation('softmax'))

    return model

model = logistic_model()
model.compile(loss='categorical_crossentropy',
              optimizer='adadelta',
              metrics=['accuracy'])
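Even this one-layer model isn't tiny: a Dense layer has input_dim * output_dim weights plus output_dim biases.

```python
# Parameter count for Dense(output_dim=2, input_dim=227*227):
n_params = 227 * 227 * 2 + 2  # weights + biases
```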
In [65]:
Y_train.shape
Out[65]:
(95, 2)
In [66]:
history = model.fit(X_train_gray, Y_train,
                    batch_size=16, nb_epoch=100, verbose=1,
         validation_data=(X_valid_gray, Y_valid))
Train on 95 samples, validate on 29 samples
Epoch 1/100
95/95 [==============================] - 0s - loss: 7.2522 - acc: 0.3579 - val_loss: 5.0068 - val_acc: 0.5862
Epoch 2/100
95/95 [==============================] - 0s - loss: 4.9483 - acc: 0.5263 - val_loss: 2.0935 - val_acc: 0.4138
Epoch 3/100
95/95 [==============================] - 0s - loss: 3.2909 - acc: 0.5684 - val_loss: 7.3005 - val_acc: 0.4138
Epoch 4/100
95/95 [==============================] - 0s - loss: 2.6830 - acc: 0.6211 - val_loss: 2.9015 - val_acc: 0.6897
Epoch 5/100
95/95 [==============================] - 0s - loss: 2.4143 - acc: 0.6211 - val_loss: 2.2550 - val_acc: 0.7586
Epoch 6/100
95/95 [==============================] - 0s - loss: 3.4747 - acc: 0.5684 - val_loss: 1.4511 - val_acc: 0.7241
Epoch 7/100
95/95 [==============================] - 0s - loss: 1.2240 - acc: 0.7579 - val_loss: 1.7592 - val_acc: 0.5172
Epoch 8/100
95/95 [==============================] - 0s - loss: 4.9564 - acc: 0.4632 - val_loss: 3.5317 - val_acc: 0.6897
Epoch 9/100
95/95 [==============================] - 0s - loss: 1.9063 - acc: 0.7053 - val_loss: 3.0339 - val_acc: 0.5172
Epoch 10/100
95/95 [==============================] - 0s - loss: 1.4162 - acc: 0.7579 - val_loss: 1.0628 - val_acc: 0.6552
Epoch 11/100
95/95 [==============================] - 0s - loss: 0.7367 - acc: 0.8105 - val_loss: 2.8841 - val_acc: 0.5172
Epoch 12/100
95/95 [==============================] - 0s - loss: 1.7781 - acc: 0.6842 - val_loss: 0.9417 - val_acc: 0.7241
Epoch 13/100
95/95 [==============================] - 0s - loss: 0.7498 - acc: 0.8000 - val_loss: 4.3745 - val_acc: 0.5862
Epoch 14/100
95/95 [==============================] - 0s - loss: 2.2542 - acc: 0.7053 - val_loss: 2.5444 - val_acc: 0.7241
Epoch 15/100
95/95 [==============================] - 0s - loss: 1.7370 - acc: 0.7474 - val_loss: 1.3782 - val_acc: 0.5862
Epoch 16/100
95/95 [==============================] - 0s - loss: 1.9214 - acc: 0.6737 - val_loss: 1.6002 - val_acc: 0.7931
Epoch 17/100
95/95 [==============================] - 0s - loss: 1.0199 - acc: 0.8000 - val_loss: 2.7927 - val_acc: 0.6897
Epoch 18/100
95/95 [==============================] - 0s - loss: 1.9498 - acc: 0.6947 - val_loss: 2.5903 - val_acc: 0.7241
Epoch 19/100
95/95 [==============================] - 0s - loss: 0.7363 - acc: 0.8211 - val_loss: 4.8215 - val_acc: 0.5862
Epoch 20/100
95/95 [==============================] - 0s - loss: 1.9845 - acc: 0.7684 - val_loss: 1.4873 - val_acc: 0.8276
Epoch 21/100
95/95 [==============================] - 0s - loss: 1.0453 - acc: 0.7789 - val_loss: 1.4767 - val_acc: 0.6207
Epoch 22/100
95/95 [==============================] - 0s - loss: 1.9646 - acc: 0.7053 - val_loss: 1.2845 - val_acc: 0.7931
Epoch 23/100
95/95 [==============================] - 0s - loss: 0.8001 - acc: 0.7579 - val_loss: 2.4354 - val_acc: 0.5172
Epoch 24/100
95/95 [==============================] - 0s - loss: 2.0241 - acc: 0.6632 - val_loss: 4.5978 - val_acc: 0.5172
Epoch 25/100
95/95 [==============================] - 0s - loss: 1.0978 - acc: 0.8421 - val_loss: 1.2032 - val_acc: 0.7586
Epoch 26/100
95/95 [==============================] - 0s - loss: 0.9482 - acc: 0.8000 - val_loss: 1.1308 - val_acc: 0.7241
Epoch 27/100
95/95 [==============================] - 0s - loss: 0.2699 - acc: 0.9158 - val_loss: 2.0574 - val_acc: 0.7931
Epoch 28/100
95/95 [==============================] - 0s - loss: 0.2321 - acc: 0.9263 - val_loss: 1.7931 - val_acc: 0.7931
Epoch 29/100
95/95 [==============================] - 0s - loss: 1.9994 - acc: 0.6632 - val_loss: 1.0347 - val_acc: 0.7241
Epoch 30/100
95/95 [==============================] - 0s - loss: 0.6135 - acc: 0.8632 - val_loss: 4.3345 - val_acc: 0.5862
Epoch 31/100
95/95 [==============================] - 0s - loss: 0.6650 - acc: 0.9053 - val_loss: 1.1803 - val_acc: 0.6207
Epoch 32/100
95/95 [==============================] - 0s - loss: 0.2287 - acc: 0.9368 - val_loss: 1.8179 - val_acc: 0.5862
Epoch 33/100
95/95 [==============================] - 0s - loss: 0.5601 - acc: 0.8316 - val_loss: 0.9933 - val_acc: 0.7586
Epoch 34/100
95/95 [==============================] - 0s - loss: 0.0615 - acc: 0.9474 - val_loss: 1.5226 - val_acc: 0.7931
Epoch 35/100
95/95 [==============================] - 0s - loss: 0.0670 - acc: 0.9579 - val_loss: 0.8354 - val_acc: 0.7586
Epoch 36/100
95/95 [==============================] - 0s - loss: 0.8917 - acc: 0.8211 - val_loss: 2.3119 - val_acc: 0.7586
Epoch 37/100
95/95 [==============================] - 0s - loss: 1.2871 - acc: 0.7789 - val_loss: 1.0523 - val_acc: 0.7586
Epoch 38/100
95/95 [==============================] - 0s - loss: 0.0232 - acc: 1.0000 - val_loss: 1.2485 - val_acc: 0.7931
Epoch 39/100
95/95 [==============================] - 0s - loss: 0.0453 - acc: 0.9684 - val_loss: 0.8709 - val_acc: 0.6897
Epoch 40/100
95/95 [==============================] - 0s - loss: 0.0285 - acc: 0.9895 - val_loss: 0.8730 - val_acc: 0.7931
Epoch 41/100
95/95 [==============================] - 0s - loss: 0.0106 - acc: 1.0000 - val_loss: 0.9136 - val_acc: 0.7586
Epoch 42/100
95/95 [==============================] - 0s - loss: 0.0098 - acc: 1.0000 - val_loss: 0.9271 - val_acc: 0.7586
Epoch 43/100
95/95 [==============================] - 0s - loss: 0.0094 - acc: 1.0000 - val_loss: 0.8433 - val_acc: 0.7931
Epoch 44/100
95/95 [==============================] - 0s - loss: 0.0095 - acc: 1.0000 - val_loss: 0.8829 - val_acc: 0.7586
Epoch 45/100
95/95 [==============================] - 0s - loss: 0.0089 - acc: 1.0000 - val_loss: 0.8444 - val_acc: 0.7931
Epoch 46/100
95/95 [==============================] - 0s - loss: 0.0082 - acc: 1.0000 - val_loss: 0.8387 - val_acc: 0.7931
Epoch 47/100
95/95 [==============================] - 0s - loss: 0.0079 - acc: 1.0000 - val_loss: 0.8016 - val_acc: 0.7931
Epoch 48/100
95/95 [==============================] - 0s - loss: 0.0078 - acc: 1.0000 - val_loss: 0.8604 - val_acc: 0.7586
Epoch 49/100
95/95 [==============================] - 0s - loss: 0.0069 - acc: 1.0000 - val_loss: 0.8115 - val_acc: 0.7931
Epoch 50/100
95/95 [==============================] - 0s - loss: 0.0087 - acc: 1.0000 - val_loss: 0.7756 - val_acc: 0.7931
Epoch 51/100
95/95 [==============================] - 0s - loss: 0.0071 - acc: 1.0000 - val_loss: 0.7941 - val_acc: 0.7931
Epoch 52/100
95/95 [==============================] - 0s - loss: 0.0067 - acc: 1.0000 - val_loss: 0.7875 - val_acc: 0.7931
Epoch 53/100
95/95 [==============================] - 0s - loss: 0.0063 - acc: 1.0000 - val_loss: 0.9167 - val_acc: 0.7586
Epoch 54/100
95/95 [==============================] - 0s - loss: 0.0120 - acc: 1.0000 - val_loss: 0.8457 - val_acc: 0.7586
Epoch 55/100
95/95 [==============================] - 0s - loss: 0.0066 - acc: 1.0000 - val_loss: 0.8763 - val_acc: 0.7586
Epoch 56/100
95/95 [==============================] - 0s - loss: 0.0071 - acc: 1.0000 - val_loss: 0.9077 - val_acc: 0.7586
Epoch 57/100
95/95 [==============================] - 0s - loss: 0.0066 - acc: 1.0000 - val_loss: 0.8614 - val_acc: 0.7586
Epoch 58/100
95/95 [==============================] - 0s - loss: 1.9226 - acc: 0.7158 - val_loss: 4.0855 - val_acc: 0.5862
Epoch 59/100
95/95 [==============================] - 0s - loss: 1.9000 - acc: 0.7263 - val_loss: 1.1361 - val_acc: 0.7931
Epoch 60/100
95/95 [==============================] - 0s - loss: 0.0170 - acc: 1.0000 - val_loss: 0.8798 - val_acc: 0.7241
Epoch 61/100
95/95 [==============================] - 0s - loss: 0.0128 - acc: 1.0000 - val_loss: 1.1915 - val_acc: 0.7931
Epoch 62/100
95/95 [==============================] - 0s - loss: 0.0113 - acc: 1.0000 - val_loss: 0.9613 - val_acc: 0.7586
Epoch 63/100
95/95 [==============================] - 0s - loss: 0.0073 - acc: 1.0000 - val_loss: 1.0751 - val_acc: 0.7586
Epoch 64/100
95/95 [==============================] - 0s - loss: 0.0076 - acc: 1.0000 - val_loss: 0.9223 - val_acc: 0.7586
Epoch 65/100
95/95 [==============================] - 0s - loss: 0.0079 - acc: 1.0000 - val_loss: 0.8807 - val_acc: 0.7586
Epoch 66/100
95/95 [==============================] - 0s - loss: 0.0065 - acc: 1.0000 - val_loss: 0.9552 - val_acc: 0.7931
Epoch 67/100
95/95 [==============================] - 0s - loss: 0.0055 - acc: 1.0000 - val_loss: 0.9431 - val_acc: 0.7931
Epoch 68/100
95/95 [==============================] - 0s - loss: 0.0057 - acc: 1.0000 - val_loss: 0.9436 - val_acc: 0.7931
Epoch 69/100
95/95 [==============================] - 0s - loss: 0.0053 - acc: 1.0000 - val_loss: 0.9743 - val_acc: 0.7586
Epoch 70/100
95/95 [==============================] - 0s - loss: 0.0047 - acc: 1.0000 - val_loss: 0.9179 - val_acc: 0.7931
Epoch 71/100
95/95 [==============================] - 0s - loss: 0.0054 - acc: 1.0000 - val_loss: 0.9302 - val_acc: 0.7586
Epoch 72/100
95/95 [==============================] - 0s - loss: 0.0043 - acc: 1.0000 - val_loss: 1.0118 - val_acc: 0.7586
Epoch 73/100
95/95 [==============================] - 0s - loss: 0.0048 - acc: 1.0000 - val_loss: 0.9239 - val_acc: 0.7586
Epoch 74/100
95/95 [==============================] - 0s - loss: 0.0041 - acc: 1.0000 - val_loss: 1.0160 - val_acc: 0.7586
Epoch 75/100
95/95 [==============================] - 0s - loss: 0.0039 - acc: 1.0000 - val_loss: 0.9258 - val_acc: 0.7586
Epoch 76/100
95/95 [==============================] - 0s - loss: 0.0039 - acc: 1.0000 - val_loss: 0.9625 - val_acc: 0.7586
Epoch 77/100
95/95 [==============================] - 0s - loss: 0.0035 - acc: 1.0000 - val_loss: 1.0869 - val_acc: 0.7931
Epoch 78/100
95/95 [==============================] - 0s - loss: 0.0038 - acc: 1.0000 - val_loss: 0.8638 - val_acc: 0.7931
Epoch 79/100
95/95 [==============================] - 0s - loss: 0.0040 - acc: 1.0000 - val_loss: 0.8562 - val_acc: 0.7931
Epoch 80/100
95/95 [==============================] - 0s - loss: 0.0036 - acc: 1.0000 - val_loss: 1.1895 - val_acc: 0.7931
Epoch 81/100
95/95 [==============================] - 0s - loss: 1.3726 - acc: 0.7895 - val_loss: 5.7269 - val_acc: 0.5862
Epoch 82/100
95/95 [==============================] - 0s - loss: 2.9303 - acc: 0.7158 - val_loss: 1.8757 - val_acc: 0.7586
Epoch 83/100
95/95 [==============================] - 0s - loss: 1.4360 - acc: 0.7895 - val_loss: 2.1063 - val_acc: 0.6207
Epoch 84/100
95/95 [==============================] - 0s - loss: 0.0835 - acc: 0.9579 - val_loss: 1.4317 - val_acc: 0.7586
Epoch 85/100
95/95 [==============================] - 0s - loss: 0.1415 - acc: 0.9684 - val_loss: 1.3718 - val_acc: 0.6897
Epoch 86/100
95/95 [==============================] - 0s - loss: 0.2231 - acc: 0.9684 - val_loss: 6.5216 - val_acc: 0.4483
Epoch 87/100
95/95 [==============================] - 0s - loss: 5.3983 - acc: 0.5053 - val_loss: 3.4698 - val_acc: 0.7241
Epoch 88/100
95/95 [==============================] - 0s - loss: 2.1108 - acc: 0.7474 - val_loss: 1.4435 - val_acc: 0.6897
Epoch 89/100
95/95 [==============================] - 0s - loss: 0.0242 - acc: 0.9895 - val_loss: 1.3336 - val_acc: 0.6897
Epoch 90/100
95/95 [==============================] - 0s - loss: 0.0113 - acc: 1.0000 - val_loss: 1.3468 - val_acc: 0.6897
Epoch 91/100
95/95 [==============================] - 0s - loss: 0.0085 - acc: 1.0000 - val_loss: 1.3176 - val_acc: 0.6897
Epoch 92/100
95/95 [==============================] - 0s - loss: 0.0064 - acc: 1.0000 - val_loss: 1.3492 - val_acc: 0.7241
Epoch 93/100
95/95 [==============================] - 0s - loss: 0.0048 - acc: 1.0000 - val_loss: 1.3506 - val_acc: 0.7241
Epoch 94/100
95/95 [==============================] - 0s - loss: 0.0060 - acc: 1.0000 - val_loss: 1.3145 - val_acc: 0.7241
Epoch 95/100
95/95 [==============================] - 0s - loss: 0.0037 - acc: 1.0000 - val_loss: 1.2798 - val_acc: 0.6897
Epoch 96/100
95/95 [==============================] - 0s - loss: 0.0037 - acc: 1.0000 - val_loss: 1.2985 - val_acc: 0.7241
Epoch 97/100
95/95 [==============================] - 0s - loss: 0.0032 - acc: 1.0000 - val_loss: 1.3054 - val_acc: 0.7241
Epoch 98/100
95/95 [==============================] - 0s - loss: 0.0030 - acc: 1.0000 - val_loss: 1.3084 - val_acc: 0.7586
Epoch 99/100
95/95 [==============================] - 0s - loss: 0.0028 - acc: 1.0000 - val_loss: 1.2935 - val_acc: 0.7586
Epoch 100/100
95/95 [==============================] - 0s - loss: 0.0028 - acc: 1.0000 - val_loss: 1.3065 - val_acc: 0.7586
In [67]:
plot_training_curves(history.history);
In [68]:
score = model.evaluate(X_test_gray, Y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
('Test loss:', 1.5162569284439087)
('Test accuracy:', 0.72413790225982666)
In [69]:
predict_train = model.predict(X_train_gray)
predict_valid = model.predict(X_valid_gray)
predict_test = model.predict(X_test_gray)

Surprise! It's hugely overfitting, with over 50K weights per class and only a couple of hundred images. Not very useful.

We'll see if a simple conv net does a bit better. Then we can try data augmentation, regularization, and perhaps making the problem simpler by using the bounding boxes. First, though, let's put together a bit more infrastructure:

  • a way to look at training and test examples, both right and wrong
  • a TensorBoard setup so we can look at weight distributions
In [37]:
# Look at training data -- there's so little we can look at all of it

def plot_data(xs, ys, predicts):
    """Plot the images in xs, with corresponding correct labels
    and predictions.
    
    Args:
        xs: RGB or grayscale images with float32 values in [0,1].
        ys: one-hot encoded labels
        predicts: probability vectors (same dim as ys, normalized e.g. via softmax)
    """
    
    # sort all 3 by ys
    xs, ys, ps = zip(*sorted(zip(xs, ys, predicts), 
                             key=lambda tpl: tpl[1][0]))
    n = len(xs)
    rows = (n + 9) // 10  # ceiling division: 10 images per row
    fig, plots = plt.subplots(rows,10, sharex='all', sharey='all',
                             figsize=(20,2*rows), squeeze=False)
    for i in range(n):
        # pick the subplot for this image
        ax = plots[i // 10, i % 10]
        ax.axis('off')
        img = xs[i].reshape(227,227,-1) 

        if img.shape[-1] == 1: # Grayscale
            # Get rid of the unneeded dimension
            img = img.squeeze()
            # flip grayscale:
            img = 1-img 
            
        ax.imshow(img)
        # dot with one-hot vector picks out right element
        pcorrect = np.dot(ps[i], ys[i]) 
        if pcorrect > 0.8:
            color = "blue"
        else:
            color = "red"
        ax.set_title("{}   p={:.2f}".format(int(ys[i][0]), pcorrect),
                     loc='center', fontsize=18, color=color)
    return fig
In [71]:
fig = plot_data(X_train_gray, Y_train, predict_train)
fig.suptitle("Train")

fig = plot_data(X_valid_gray, Y_valid, predict_valid)
fig.suptitle("Valid")

fig = plot_data(X_test_gray, Y_test, predict_test)
fig.suptitle("Test")
Out[71]:
<matplotlib.text.Text at 0x148e647d0>

Looking at this really highlights how little data we have. It's probably not worth using ~40% of it for validation and test. Cross-validation would be better, and with this little data would probably be fast enough. We could definitely benefit from data augmentation too.

On the other hand, looks like the images are pretty well centered -- bounding boxes might help a bit, but there are very few gross scale differences.

While we're building infrastructure, let's change the plots above to include whether we got each example right or wrong. [done]
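Cross-validation, as suggested above, is easy to set up by hand at this scale. A hedged sketch in plain NumPy (not what this notebook actually does) that splits the pooled 95 + 29 train and validation images into 5 folds:

```python
import numpy as np

def kfold_indices(n, k, seed=0):
    # shuffle once, then split the indices into k nearly equal validation folds
    idx = np.random.RandomState(seed).permutation(n)
    return np.array_split(idx, k)

folds = kfold_indices(95 + 29, 5)    # pool the train and valid sets
sizes = [len(f) for f in folds]
# sizes == [25, 25, 25, 25, 24]; every image is held out exactly once
```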

Now, let's try a basic CNN

We'll use three conv layers, then a fully connected one, with dropout to try to fight overfitting...

In [38]:
# normalize the data, this time leaving it in color
def normalize_for_cnn(xs):
    ret = (xs / 255.0)
    return ret
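A common refinement over plain [0, 1] scaling is subtracting per-channel means computed on the training set; pretrained ImageNet models typically expect mean-subtracted inputs. A hedged sketch of what that would look like (not used in this notebook):

```python
import numpy as np

def normalize_centered(xs, channel_means):
    # scale to [0, 1], then subtract per-channel means from the training set
    return xs / 255.0 - channel_means

train = np.random.randint(0, 256, size=(8, 227, 227, 3)).astype(np.float32)
means = (train / 255.0).mean(axis=(0, 1, 2))   # one mean per RGB channel
centered = normalize_centered(train, means)
# each channel of the centered training data now has (near-)zero mean
```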
In [78]:
X_train_norm = normalize_for_cnn(X_train)
X_valid_norm = normalize_for_cnn(X_valid)
X_test_norm = normalize_for_cnn(X_test)
In [79]:
X_train_norm.shape
Out[79]:
(95, 227, 227, 3)
In [41]:
def cnn_model(use_dropout=True):
    model = Sequential()
    nb_filters = 16
    pool_size = (2,2)
    filter_size = 3
    nb_classes = 2
    
    with tf.name_scope("conv1") as scope:
        model.add(Convolution2D(nb_filters, filter_size,
                            input_shape=(227, 227, 3)))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=pool_size))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("conv2") as scope:
        model.add(Convolution2D(nb_filters, filter_size))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=pool_size))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("conv3") as scope:
        model.add(Convolution2D(nb_filters, filter_size))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=pool_size))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("dense1") as scope:
        model.add(Flatten())
        model.add(Dense(16))
        model.add(Activation('relu'))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("softmax") as scope:
        model.add(Dense(nb_classes))
        model.add(Activation('softmax'))
    return model

# Uncomment if you get an "Invalid argument: You must feed a value
# for placeholder tensor ..." error when rerunning training.
# K.clear_session() # https://github.com/fchollet/keras/issues/4499
    

model2 = cnn_model()
model2.compile(loss='categorical_crossentropy',
              optimizer='adadelta',
              metrics=['accuracy'])
In [42]:
print(model2.summary())
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_4 (Conv2D)            (None, 225, 225, 16)      448       
_________________________________________________________________
activation_6 (Activation)    (None, 225, 225, 16)      0         
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 112, 112, 16)      0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 112, 112, 16)      0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 110, 110, 16)      2320      
_________________________________________________________________
activation_7 (Activation)    (None, 110, 110, 16)      0         
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 55, 55, 16)        0         
_________________________________________________________________
dropout_6 (Dropout)          (None, 55, 55, 16)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 53, 53, 16)        2320      
_________________________________________________________________
activation_8 (Activation)    (None, 53, 53, 16)        0         
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 26, 26, 16)        0         
_________________________________________________________________
dropout_7 (Dropout)          (None, 26, 26, 16)        0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 10816)             0         
_________________________________________________________________
dense_3 (Dense)              (None, 16)                173072    
_________________________________________________________________
activation_9 (Activation)    (None, 16)                0         
_________________________________________________________________
dropout_8 (Dropout)          (None, 16)                0         
_________________________________________________________________
dense_4 (Dense)              (None, 2)                 34        
_________________________________________________________________
activation_10 (Activation)   (None, 2)                 0         
=================================================================
Total params: 178,194.0
Trainable params: 178,194.0
Non-trainable params: 0.0
_________________________________________________________________
None
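The output shapes in the summary follow from simple arithmetic: each 3×3 'valid' convolution shrinks the spatial size by 2, and each 2×2 max pool halves it (rounding down). Verifying the 227 → 26 progression and the 10,816-unit Flatten:

```python
def conv_out(n, k=3):
    # 'valid' convolution: output shrinks by kernel_size - 1
    return n - (k - 1)

def pool_out(n, p=2):
    # non-overlapping pooling: floor division
    return n // p

n = 227
for _ in range(3):
    n = pool_out(conv_out(n))   # 227->225->112, 112->110->55, 55->53->26
print(n, n * n * 16)            # 26 10816 -- matches flatten_2 above
```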
In [130]:
recompute = False

if recompute:
    
    # Save info during computation so we can see what's happening
    tbCallback = TensorBoard(
        log_dir='./graph', histogram_freq=1, 
        write_graph=False, write_images=False)
    
    # Fit the model!
    history = model2.fit(
        X_train_norm, Y_train,
        batch_size=16, nb_epoch=100, verbose=1,
        validation_data=(X_valid_norm, Y_valid),
        callbacks=[tbCallback]
    )
else:
    model2.load_weights('hummer_acura_simple_cnn.h5')
Train on 95 samples, validate on 29 samples
Epoch 1/100
95/95 [==============================] - 5s - loss: 0.6914 - acc: 0.5579 - val_loss: 0.6880 - val_acc: 0.5862
Epoch 2/100
95/95 [==============================] - 5s - loss: 0.6856 - acc: 0.5895 - val_loss: 0.6893 - val_acc: 0.6552
Epoch 3/100
95/95 [==============================] - 5s - loss: 0.6683 - acc: 0.5368 - val_loss: 0.6863 - val_acc: 0.6552
Epoch 4/100
95/95 [==============================] - 5s - loss: 0.6782 - acc: 0.5789 - val_loss: 0.6887 - val_acc: 0.6552
Epoch 5/100
95/95 [==============================] - 4s - loss: 0.6738 - acc: 0.5579 - val_loss: 0.6829 - val_acc: 0.6207
Epoch 6/100
95/95 [==============================] - 5s - loss: 0.6751 - acc: 0.5684 - val_loss: 0.6877 - val_acc: 0.6552
Epoch 7/100
95/95 [==============================] - 5s - loss: 0.6600 - acc: 0.6316 - val_loss: 0.6863 - val_acc: 0.7241
Epoch 8/100
95/95 [==============================] - 5s - loss: 0.6572 - acc: 0.6737 - val_loss: 0.6870 - val_acc: 0.6552
Epoch 9/100
95/95 [==============================] - 5s - loss: 0.6532 - acc: 0.6421 - val_loss: 0.6859 - val_acc: 0.7241
Epoch 10/100
95/95 [==============================] - 5s - loss: 0.6501 - acc: 0.6211 - val_loss: 0.6844 - val_acc: 0.7241
Epoch 11/100
95/95 [==============================] - 5s - loss: 0.6378 - acc: 0.6632 - val_loss: 0.6788 - val_acc: 0.7241
Epoch 12/100
95/95 [==============================] - 5s - loss: 0.6235 - acc: 0.6211 - val_loss: 0.6696 - val_acc: 0.7241
Epoch 13/100
95/95 [==============================] - 5s - loss: 0.6174 - acc: 0.7053 - val_loss: 0.6650 - val_acc: 0.6552
Epoch 14/100
95/95 [==============================] - 5s - loss: 0.5937 - acc: 0.6947 - val_loss: 0.6585 - val_acc: 0.7241
Epoch 15/100
95/95 [==============================] - 5s - loss: 0.5211 - acc: 0.7579 - val_loss: 0.6433 - val_acc: 0.7241
Epoch 16/100
95/95 [==============================] - 5s - loss: 0.6316 - acc: 0.6632 - val_loss: 0.6380 - val_acc: 0.7241
Epoch 17/100
95/95 [==============================] - 5s - loss: 0.5487 - acc: 0.7579 - val_loss: 0.6202 - val_acc: 0.7241
Epoch 18/100
95/95 [==============================] - 5s - loss: 0.5338 - acc: 0.8211 - val_loss: 0.6368 - val_acc: 0.7586
Epoch 19/100
95/95 [==============================] - 5s - loss: 0.5266 - acc: 0.7368 - val_loss: 0.5969 - val_acc: 0.7931
Epoch 20/100
95/95 [==============================] - 5s - loss: 0.5201 - acc: 0.6526 - val_loss: 0.6188 - val_acc: 0.7241
Epoch 21/100
95/95 [==============================] - 5s - loss: 0.5393 - acc: 0.7368 - val_loss: 0.6296 - val_acc: 0.7931
Epoch 22/100
95/95 [==============================] - 5s - loss: 0.5538 - acc: 0.7789 - val_loss: 0.6039 - val_acc: 0.7931
Epoch 23/100
95/95 [==============================] - 5s - loss: 0.4918 - acc: 0.7579 - val_loss: 0.5884 - val_acc: 0.7931
Epoch 24/100
95/95 [==============================] - 5s - loss: 0.4972 - acc: 0.7684 - val_loss: 0.5884 - val_acc: 0.8276
Epoch 25/100
95/95 [==============================] - 5s - loss: 0.4925 - acc: 0.7474 - val_loss: 0.5530 - val_acc: 0.7931
Epoch 26/100
95/95 [==============================] - 5s - loss: 0.4328 - acc: 0.8316 - val_loss: 0.5435 - val_acc: 0.7931
Epoch 27/100
95/95 [==============================] - 5s - loss: 0.4626 - acc: 0.7895 - val_loss: 0.5321 - val_acc: 0.7931
Epoch 28/100
95/95 [==============================] - 5s - loss: 0.4393 - acc: 0.8105 - val_loss: 0.5390 - val_acc: 0.7931
Epoch 29/100
95/95 [==============================] - 5s - loss: 0.4124 - acc: 0.8105 - val_loss: 0.4984 - val_acc: 0.7241
Epoch 30/100
95/95 [==============================] - 5s - loss: 0.4338 - acc: 0.8105 - val_loss: 0.4776 - val_acc: 0.7586
Epoch 31/100
95/95 [==============================] - 5s - loss: 0.3461 - acc: 0.8632 - val_loss: 0.4598 - val_acc: 0.7241
Epoch 32/100
95/95 [==============================] - 5s - loss: 0.3778 - acc: 0.8316 - val_loss: 0.5046 - val_acc: 0.7241
Epoch 33/100
95/95 [==============================] - 5s - loss: 0.3717 - acc: 0.8105 - val_loss: 0.4466 - val_acc: 0.7931
Epoch 34/100
95/95 [==============================] - 5s - loss: 0.4031 - acc: 0.8211 - val_loss: 0.4645 - val_acc: 0.7241
Epoch 35/100
95/95 [==============================] - 5s - loss: 0.3973 - acc: 0.7579 - val_loss: 0.4675 - val_acc: 0.7931
Epoch 36/100
95/95 [==============================] - 5s - loss: 0.3661 - acc: 0.8421 - val_loss: 0.4373 - val_acc: 0.7241
Epoch 37/100
95/95 [==============================] - 5s - loss: 0.3475 - acc: 0.8105 - val_loss: 0.4501 - val_acc: 0.6897
Epoch 38/100
95/95 [==============================] - 5s - loss: 0.3418 - acc: 0.8632 - val_loss: 0.4275 - val_acc: 0.6897
Epoch 39/100
95/95 [==============================] - 5s - loss: 0.4382 - acc: 0.8421 - val_loss: 0.4451 - val_acc: 0.7586
Epoch 40/100
95/95 [==============================] - 5s - loss: 0.3511 - acc: 0.8421 - val_loss: 0.4312 - val_acc: 0.7586
Epoch 41/100
95/95 [==============================] - 5s - loss: 0.3105 - acc: 0.8947 - val_loss: 0.4691 - val_acc: 0.7586
Epoch 42/100
95/95 [==============================] - 5s - loss: 0.4576 - acc: 0.8105 - val_loss: 0.4655 - val_acc: 0.7931
Epoch 43/100
95/95 [==============================] - 5s - loss: 0.2975 - acc: 0.9158 - val_loss: 0.4190 - val_acc: 0.7241
Epoch 44/100
95/95 [==============================] - 5s - loss: 0.3235 - acc: 0.8632 - val_loss: 0.4226 - val_acc: 0.8276
Epoch 45/100
95/95 [==============================] - 5s - loss: 0.2526 - acc: 0.9368 - val_loss: 0.4189 - val_acc: 0.7241
Epoch 46/100
95/95 [==============================] - 5s - loss: 0.4299 - acc: 0.8421 - val_loss: 0.4129 - val_acc: 0.7931
Epoch 47/100
95/95 [==============================] - 5s - loss: 0.2710 - acc: 0.9053 - val_loss: 0.4112 - val_acc: 0.7241
Epoch 48/100
95/95 [==============================] - 5s - loss: 0.3607 - acc: 0.8632 - val_loss: 0.4287 - val_acc: 0.7931
Epoch 49/100
95/95 [==============================] - 5s - loss: 0.2786 - acc: 0.8632 - val_loss: 0.4321 - val_acc: 0.8276
Epoch 50/100
95/95 [==============================] - 5s - loss: 0.2770 - acc: 0.8737 - val_loss: 0.4112 - val_acc: 0.8276
Epoch 51/100
95/95 [==============================] - 5s - loss: 0.2464 - acc: 0.9263 - val_loss: 0.4098 - val_acc: 0.8276
Epoch 52/100
95/95 [==============================] - 5s - loss: 0.1824 - acc: 0.9474 - val_loss: 0.4030 - val_acc: 0.7586
Epoch 53/100
95/95 [==============================] - 5s - loss: 0.2632 - acc: 0.9158 - val_loss: 0.4148 - val_acc: 0.8276
Epoch 54/100
95/95 [==============================] - 5s - loss: 0.2895 - acc: 0.8421 - val_loss: 0.3866 - val_acc: 0.8276
Epoch 55/100
95/95 [==============================] - 5s - loss: 0.2186 - acc: 0.9158 - val_loss: 0.3817 - val_acc: 0.8276
Epoch 56/100
95/95 [==============================] - 5s - loss: 0.2248 - acc: 0.9368 - val_loss: 0.4008 - val_acc: 0.8276
Epoch 57/100
95/95 [==============================] - 5s - loss: 0.1989 - acc: 0.9053 - val_loss: 0.3888 - val_acc: 0.8276
Epoch 58/100
95/95 [==============================] - 5s - loss: 0.1902 - acc: 0.9158 - val_loss: 0.3851 - val_acc: 0.8276
Epoch 59/100
95/95 [==============================] - 5s - loss: 0.1485 - acc: 0.9684 - val_loss: 0.3854 - val_acc: 0.8276
Epoch 60/100
95/95 [==============================] - 5s - loss: 0.1657 - acc: 0.9263 - val_loss: 0.3963 - val_acc: 0.7241
Epoch 61/100
95/95 [==============================] - 5s - loss: 0.2380 - acc: 0.9158 - val_loss: 0.3923 - val_acc: 0.8276
Epoch 62/100
95/95 [==============================] - 5s - loss: 0.1907 - acc: 0.9263 - val_loss: 0.4024 - val_acc: 0.8276
Epoch 63/100
95/95 [==============================] - 5s - loss: 0.1586 - acc: 0.9579 - val_loss: 0.4029 - val_acc: 0.8276
Epoch 64/100
95/95 [==============================] - 5s - loss: 0.2352 - acc: 0.9368 - val_loss: 0.3586 - val_acc: 0.8276
Epoch 65/100
95/95 [==============================] - 5s - loss: 0.2612 - acc: 0.9368 - val_loss: 0.4281 - val_acc: 0.7586
Epoch 66/100
95/95 [==============================] - 5s - loss: 0.2030 - acc: 0.9158 - val_loss: 0.3955 - val_acc: 0.8621
Epoch 67/100
95/95 [==============================] - 5s - loss: 0.1239 - acc: 0.9579 - val_loss: 0.4857 - val_acc: 0.8276
Epoch 68/100
95/95 [==============================] - 5s - loss: 0.2595 - acc: 0.9263 - val_loss: 0.3724 - val_acc: 0.8276
Epoch 69/100
95/95 [==============================] - 5s - loss: 0.1228 - acc: 0.9579 - val_loss: 0.3711 - val_acc: 0.8621
Epoch 70/100
95/95 [==============================] - 5s - loss: 0.2284 - acc: 0.9158 - val_loss: 0.3500 - val_acc: 0.8276
Epoch 71/100
95/95 [==============================] - 5s - loss: 0.1793 - acc: 0.9684 - val_loss: 0.3408 - val_acc: 0.8621
Epoch 72/100
95/95 [==============================] - 5s - loss: 0.1197 - acc: 0.9474 - val_loss: 0.3261 - val_acc: 0.7586
Epoch 73/100
95/95 [==============================] - 5s - loss: 0.1587 - acc: 0.9158 - val_loss: 0.3317 - val_acc: 0.8276
Epoch 74/100
95/95 [==============================] - 5s - loss: 0.1559 - acc: 0.9474 - val_loss: 0.3611 - val_acc: 0.8621
Epoch 75/100
95/95 [==============================] - 5s - loss: 0.1357 - acc: 0.9368 - val_loss: 0.3532 - val_acc: 0.7931
Epoch 76/100
95/95 [==============================] - 5s - loss: 0.1473 - acc: 0.9368 - val_loss: 0.3702 - val_acc: 0.8621
Epoch 77/100
95/95 [==============================] - 5s - loss: 0.1008 - acc: 0.9789 - val_loss: 0.4450 - val_acc: 0.8621
Epoch 78/100
95/95 [==============================] - 5s - loss: 0.1412 - acc: 0.9368 - val_loss: 0.4363 - val_acc: 0.8276
Epoch 79/100
95/95 [==============================] - 5s - loss: 0.1169 - acc: 0.9579 - val_loss: 0.3704 - val_acc: 0.7931
Epoch 80/100
95/95 [==============================] - 5s - loss: 0.0959 - acc: 0.9895 - val_loss: 0.3623 - val_acc: 0.7586
Epoch 81/100
95/95 [==============================] - 5s - loss: 0.1971 - acc: 0.9158 - val_loss: 0.3684 - val_acc: 0.7586
Epoch 82/100
95/95 [==============================] - 5s - loss: 0.0978 - acc: 0.9895 - val_loss: 0.4040 - val_acc: 0.8621
Epoch 83/100
95/95 [==============================] - 5s - loss: 0.0891 - acc: 0.9684 - val_loss: 0.4027 - val_acc: 0.7931
Epoch 84/100
95/95 [==============================] - 5s - loss: 0.1411 - acc: 0.9684 - val_loss: 0.3749 - val_acc: 0.8276
Epoch 85/100
95/95 [==============================] - 5s - loss: 0.1179 - acc: 0.9474 - val_loss: 0.3673 - val_acc: 0.8621
Epoch 86/100
95/95 [==============================] - 5s - loss: 0.1044 - acc: 0.9474 - val_loss: 0.3784 - val_acc: 0.8276
Epoch 87/100
95/95 [==============================] - 5s - loss: 0.1011 - acc: 0.9474 - val_loss: 0.3760 - val_acc: 0.8276
Epoch 88/100
95/95 [==============================] - 5s - loss: 0.1604 - acc: 0.9263 - val_loss: 0.3580 - val_acc: 0.7586
Epoch 89/100
95/95 [==============================] - 5s - loss: 0.0982 - acc: 0.9684 - val_loss: 0.4037 - val_acc: 0.8621
Epoch 90/100
95/95 [==============================] - 5s - loss: 0.1496 - acc: 0.9158 - val_loss: 0.4121 - val_acc: 0.8621
Epoch 91/100
95/95 [==============================] - 5s - loss: 0.0849 - acc: 0.9684 - val_loss: 0.3982 - val_acc: 0.8621
Epoch 92/100
95/95 [==============================] - 5s - loss: 0.1234 - acc: 0.9789 - val_loss: 0.4126 - val_acc: 0.7931
Epoch 93/100
95/95 [==============================] - 5s - loss: 0.0964 - acc: 0.9789 - val_loss: 0.4113 - val_acc: 0.8621
Epoch 94/100
95/95 [==============================] - 5s - loss: 0.1479 - acc: 0.9368 - val_loss: 0.4511 - val_acc: 0.8621
Epoch 95/100
95/95 [==============================] - 5s - loss: 0.1315 - acc: 0.9368 - val_loss: 0.4166 - val_acc: 0.8621
Epoch 96/100
95/95 [==============================] - 5s - loss: 0.0935 - acc: 0.9579 - val_loss: 0.4036 - val_acc: 0.8276
Epoch 97/100
95/95 [==============================] - 5s - loss: 0.0777 - acc: 0.9684 - val_loss: 0.4635 - val_acc: 0.8621
Epoch 98/100
95/95 [==============================] - 5s - loss: 0.1072 - acc: 0.9684 - val_loss: 0.4007 - val_acc: 0.8621
Epoch 99/100
95/95 [==============================] - 5s - loss: 0.0849 - acc: 0.9789 - val_loss: 0.4278 - val_acc: 0.8621
Epoch 100/100
95/95 [==============================] - 5s - loss: 0.1049 - acc: 0.9579 - val_loss: 0.4431 - val_acc: 0.8276
In [326]:
model2.save('hummer_acura_simple_cnn.h5')
In [131]:
plot_training_curves(history.history);
In [132]:
score = model2.evaluate(X_test_norm, Y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
('Test loss:', 0.23822678625583649)
('Test accuracy:', 0.93103450536727905)

Hmm. Seems to work remarkably well, given how little data we have! Yay for dropout. We can move on to a more interesting problem, but first, let's look at a few things:

  • take a look at the network in a bit more detail
  • try it on some hummer, acura, and other photos from the internet.
  • See how much dropout is helping...

First, let's take a look at the results:

In [84]:
predict_train = model2.predict(X_train_norm)
predict_valid = model2.predict(X_valid_norm)
predict_test = model2.predict(X_test_norm)

fig = plot_data(X_train_norm, Y_train, predict_train)
fig.suptitle("Train")

fig = plot_data(X_valid_norm, Y_valid, predict_valid)
fig.suptitle("Valid")

fig = plot_data(X_test_norm, Y_test, predict_test)
fig.suptitle("Test")
Out[84]:
<matplotlib.text.Text at 0x1407a0750>

Ok, now let's see how the same model does without dropout. Given how little data we have, I'd expect it to do a lot worse...

In [92]:
model2_nodropout = cnn_model(use_dropout=False)
model2_nodropout.compile(loss='categorical_crossentropy',
              optimizer='adadelta',
              metrics=['accuracy'])
In [93]:
history = model2_nodropout.fit(X_train_norm, Y_train, 
                               batch_size=16, nb_epoch=100, verbose=1,
         validation_data=(X_valid_norm, Y_valid))
Train on 95 samples, validate on 29 samples
Epoch 1/100
95/95 [==============================] - 4s - loss: 0.6857 - acc: 0.5368 - val_loss: 0.6680 - val_acc: 0.5862
Epoch 2/100
95/95 [==============================] - 3s - loss: 0.6846 - acc: 0.5263 - val_loss: 0.6674 - val_acc: 0.5862
Epoch 3/100
95/95 [==============================] - 3s - loss: 0.6758 - acc: 0.6000 - val_loss: 0.6609 - val_acc: 0.6207
Epoch 4/100
95/95 [==============================] - 3s - loss: 0.6578 - acc: 0.6000 - val_loss: 0.6456 - val_acc: 0.7931
Epoch 5/100
95/95 [==============================] - 3s - loss: 0.6013 - acc: 0.6947 - val_loss: 0.9653 - val_acc: 0.5862
Epoch 6/100
95/95 [==============================] - 3s - loss: 0.6211 - acc: 0.7684 - val_loss: 0.5513 - val_acc: 0.8276
Epoch 7/100
95/95 [==============================] - 3s - loss: 0.5953 - acc: 0.8211 - val_loss: 0.5341 - val_acc: 0.7586
Epoch 8/100
95/95 [==============================] - 3s - loss: 0.4788 - acc: 0.8211 - val_loss: 0.4413 - val_acc: 0.7586
Epoch 9/100
95/95 [==============================] - 3s - loss: 0.5216 - acc: 0.7789 - val_loss: 0.4678 - val_acc: 0.7931
Epoch 10/100
95/95 [==============================] - 3s - loss: 0.4020 - acc: 0.8316 - val_loss: 0.4183 - val_acc: 0.7931
Epoch 11/100
95/95 [==============================] - 3s - loss: 0.3046 - acc: 0.8737 - val_loss: 0.4329 - val_acc: 0.7931
Epoch 12/100
95/95 [==============================] - 3s - loss: 0.2745 - acc: 0.8737 - val_loss: 0.5020 - val_acc: 0.7241
Epoch 13/100
95/95 [==============================] - 3s - loss: 0.2261 - acc: 0.8842 - val_loss: 0.5202 - val_acc: 0.7586
Epoch 14/100
95/95 [==============================] - 3s - loss: 0.4743 - acc: 0.8421 - val_loss: 0.4024 - val_acc: 0.8276
Epoch 15/100
95/95 [==============================] - 3s - loss: 0.1906 - acc: 0.9579 - val_loss: 0.4209 - val_acc: 0.7241
Epoch 16/100
95/95 [==============================] - 3s - loss: 0.1088 - acc: 0.9895 - val_loss: 0.5729 - val_acc: 0.7586
Epoch 17/100
95/95 [==============================] - 3s - loss: 0.1518 - acc: 0.9474 - val_loss: 0.4448 - val_acc: 0.7586
Epoch 18/100
95/95 [==============================] - 3s - loss: 0.0891 - acc: 0.9895 - val_loss: 0.4831 - val_acc: 0.7241
Epoch 19/100
95/95 [==============================] - 3s - loss: 0.0564 - acc: 0.9789 - val_loss: 0.5046 - val_acc: 0.8276
Epoch 20/100
95/95 [==============================] - 3s - loss: 0.0347 - acc: 1.0000 - val_loss: 0.5612 - val_acc: 0.7931
Epoch 21/100
95/95 [==============================] - 3s - loss: 0.0312 - acc: 1.0000 - val_loss: 0.6632 - val_acc: 0.7241
Epoch 22/100
95/95 [==============================] - 3s - loss: 0.0236 - acc: 1.0000 - val_loss: 0.7274 - val_acc: 0.7586
Epoch 23/100
95/95 [==============================] - 3s - loss: 0.0154 - acc: 1.0000 - val_loss: 0.6853 - val_acc: 0.7586
Epoch 24/100
95/95 [==============================] - 3s - loss: 0.0150 - acc: 1.0000 - val_loss: 0.6815 - val_acc: 0.7586
Epoch 25/100
95/95 [==============================] - 3s - loss: 0.0092 - acc: 1.0000 - val_loss: 0.7193 - val_acc: 0.7586
Epoch 26/100
95/95 [==============================] - 3s - loss: 0.0077 - acc: 1.0000 - val_loss: 0.7409 - val_acc: 0.7241
Epoch 27/100
95/95 [==============================] - 3s - loss: 0.0068 - acc: 1.0000 - val_loss: 0.7707 - val_acc: 0.7586
Epoch 28/100
95/95 [==============================] - 3s - loss: 0.0053 - acc: 1.0000 - val_loss: 0.7917 - val_acc: 0.7586
Epoch 29/100
95/95 [==============================] - 3s - loss: 0.0042 - acc: 1.0000 - val_loss: 0.8126 - val_acc: 0.7586
Epoch 30/100
95/95 [==============================] - 3s - loss: 0.0035 - acc: 1.0000 - val_loss: 0.8460 - val_acc: 0.7586
Epoch 31/100
95/95 [==============================] - 3s - loss: 0.0043 - acc: 1.0000 - val_loss: 0.8354 - val_acc: 0.7241
Epoch 32/100
95/95 [==============================] - 3s - loss: 0.0027 - acc: 1.0000 - val_loss: 0.8764 - val_acc: 0.7586
Epoch 33/100
95/95 [==============================] - 3s - loss: 0.0026 - acc: 1.0000 - val_loss: 0.8740 - val_acc: 0.7241
Epoch 34/100
95/95 [==============================] - 3s - loss: 0.0020 - acc: 1.0000 - val_loss: 0.8916 - val_acc: 0.6897
Epoch 35/100
95/95 [==============================] - 3s - loss: 0.0019 - acc: 1.0000 - val_loss: 0.9362 - val_acc: 0.7586
Epoch 36/100
95/95 [==============================] - 3s - loss: 0.0018 - acc: 1.0000 - val_loss: 0.9453 - val_acc: 0.7586
Epoch 37/100
95/95 [==============================] - 3s - loss: 0.0016 - acc: 1.0000 - val_loss: 0.9522 - val_acc: 0.7586
Epoch 38/100
95/95 [==============================] - 3s - loss: 0.0013 - acc: 1.0000 - val_loss: 0.9771 - val_acc: 0.7586
Epoch 39/100
95/95 [==============================] - 3s - loss: 0.0011 - acc: 1.0000 - val_loss: 0.9812 - val_acc: 0.7241
Epoch 40/100
95/95 [==============================] - 3s - loss: 0.0011 - acc: 1.0000 - val_loss: 1.0066 - val_acc: 0.7586
Epoch 41/100
95/95 [==============================] - 3s - loss: 9.5960e-04 - acc: 1.0000 - val_loss: 1.0354 - val_acc: 0.7586
Epoch 42/100
95/95 [==============================] - 3s - loss: 8.5902e-04 - acc: 1.0000 - val_loss: 1.0447 - val_acc: 0.7586
Epoch 43/100
95/95 [==============================] - 3s - loss: 7.0911e-04 - acc: 1.0000 - val_loss: 1.0595 - val_acc: 0.7586
Epoch 44/100
95/95 [==============================] - 3s - loss: 6.6720e-04 - acc: 1.0000 - val_loss: 1.0722 - val_acc: 0.7586
Epoch 45/100
95/95 [==============================] - 3s - loss: 6.3465e-04 - acc: 1.0000 - val_loss: 1.0856 - val_acc: 0.7241
Epoch 46/100
95/95 [==============================] - 3s - loss: 5.8962e-04 - acc: 1.0000 - val_loss: 1.1213 - val_acc: 0.7586
Epoch 47/100
95/95 [==============================] - 3s - loss: 4.4358e-04 - acc: 1.0000 - val_loss: 1.1265 - val_acc: 0.7586
Epoch 48/100
95/95 [==============================] - 3s - loss: 4.5379e-04 - acc: 1.0000 - val_loss: 1.1427 - val_acc: 0.7586
Epoch 49/100
95/95 [==============================] - 3s - loss: 4.6210e-04 - acc: 1.0000 - val_loss: 1.1362 - val_acc: 0.7241
Epoch 50/100
95/95 [==============================] - 3s - loss: 3.4455e-04 - acc: 1.0000 - val_loss: 1.1701 - val_acc: 0.7586
Epoch 51/100
95/95 [==============================] - 3s - loss: 2.7071e-04 - acc: 1.0000 - val_loss: 1.2017 - val_acc: 0.7931
Epoch 52/100
95/95 [==============================] - 3s - loss: 2.4358e-04 - acc: 1.0000 - val_loss: 1.1796 - val_acc: 0.7931
Epoch 53/100
95/95 [==============================] - 3s - loss: 2.3586e-04 - acc: 1.0000 - val_loss: 1.2305 - val_acc: 0.7931
Epoch 54/100
95/95 [==============================] - 3s - loss: 2.0421e-04 - acc: 1.0000 - val_loss: 1.2299 - val_acc: 0.7931
Epoch 55/100
95/95 [==============================] - 3s - loss: 1.9359e-04 - acc: 1.0000 - val_loss: 1.1850 - val_acc: 0.7241
Epoch 56/100
95/95 [==============================] - 3s - loss: 1.7682e-04 - acc: 1.0000 - val_loss: 1.1902 - val_acc: 0.7241
Epoch 57/100
95/95 [==============================] - 3s - loss: 1.6404e-04 - acc: 1.0000 - val_loss: 1.2054 - val_acc: 0.7586
Epoch 58/100
95/95 [==============================] - 3s - loss: 1.3212e-04 - acc: 1.0000 - val_loss: 1.2391 - val_acc: 0.7931
Epoch 59/100
95/95 [==============================] - 3s - loss: 1.1638e-04 - acc: 1.0000 - val_loss: 1.2321 - val_acc: 0.7931
Epoch 60/100
95/95 [==============================] - 3s - loss: 1.0729e-04 - acc: 1.0000 - val_loss: 1.2250 - val_acc: 0.7586
Epoch 61/100
95/95 [==============================] - 3s - loss: 1.1666e-04 - acc: 1.0000 - val_loss: 1.2471 - val_acc: 0.7931
Epoch 62/100
95/95 [==============================] - 3s - loss: 8.7458e-05 - acc: 1.0000 - val_loss: 1.2467 - val_acc: 0.7931
Epoch 63/100
95/95 [==============================] - 3s - loss: 8.2138e-05 - acc: 1.0000 - val_loss: 1.2472 - val_acc: 0.7931
Epoch 64/100
95/95 [==============================] - 3s - loss: 8.1756e-05 - acc: 1.0000 - val_loss: 1.2878 - val_acc: 0.7931
Epoch 65/100
95/95 [==============================] - 3s - loss: 7.0272e-05 - acc: 1.0000 - val_loss: 1.2581 - val_acc: 0.7931
Epoch 66/100
95/95 [==============================] - 3s - loss: 6.7523e-05 - acc: 1.0000 - val_loss: 1.3053 - val_acc: 0.7931
Epoch 67/100
95/95 [==============================] - 3s - loss: 5.9961e-05 - acc: 1.0000 - val_loss: 1.2738 - val_acc: 0.7931
Epoch 68/100
95/95 [==============================] - 3s - loss: 5.7311e-05 - acc: 1.0000 - val_loss: 1.2709 - val_acc: 0.7241
Epoch 69/100
95/95 [==============================] - 3s - loss: 5.8061e-05 - acc: 1.0000 - val_loss: 1.3159 - val_acc: 0.7931
Epoch 70/100
95/95 [==============================] - 3s - loss: 4.8750e-05 - acc: 1.0000 - val_loss: 1.2932 - val_acc: 0.7931
Epoch 71/100
95/95 [==============================] - 3s - loss: 4.7013e-05 - acc: 1.0000 - val_loss: 1.3159 - val_acc: 0.7931
Epoch 72/100
95/95 [==============================] - 3s - loss: 4.3949e-05 - acc: 1.0000 - val_loss: 1.3198 - val_acc: 0.7931
Epoch 73/100
95/95 [==============================] - 3s - loss: 3.9394e-05 - acc: 1.0000 - val_loss: 1.3203 - val_acc: 0.7931
Epoch 74/100
95/95 [==============================] - 3s - loss: 3.7360e-05 - acc: 1.0000 - val_loss: 1.3251 - val_acc: 0.7931
Epoch 75/100
95/95 [==============================] - 3s - loss: 3.5351e-05 - acc: 1.0000 - val_loss: 1.3275 - val_acc: 0.7931
Epoch 76/100
95/95 [==============================] - 3s - loss: 3.3695e-05 - acc: 1.0000 - val_loss: 1.3253 - val_acc: 0.7931
Epoch 77/100
95/95 [==============================] - 3s - loss: 3.1460e-05 - acc: 1.0000 - val_loss: 1.3358 - val_acc: 0.7931
Epoch 78/100
95/95 [==============================] - 3s - loss: 2.9618e-05 - acc: 1.0000 - val_loss: 1.3327 - val_acc: 0.7586
Epoch 79/100
95/95 [==============================] - 3s - loss: 2.9065e-05 - acc: 1.0000 - val_loss: 1.3327 - val_acc: 0.7586
Epoch 80/100
95/95 [==============================] - 3s - loss: 2.7209e-05 - acc: 1.0000 - val_loss: 1.3535 - val_acc: 0.7931
Epoch 81/100
95/95 [==============================] - 3s - loss: 2.5864e-05 - acc: 1.0000 - val_loss: 1.3565 - val_acc: 0.7931
Epoch 82/100
95/95 [==============================] - 3s - loss: 2.4586e-05 - acc: 1.0000 - val_loss: 1.3568 - val_acc: 0.7931
Epoch 83/100
95/95 [==============================] - 3s - loss: 2.3244e-05 - acc: 1.0000 - val_loss: 1.3507 - val_acc: 0.7586
Epoch 84/100
95/95 [==============================] - 3s - loss: 2.2128e-05 - acc: 1.0000 - val_loss: 1.3662 - val_acc: 0.7931
Epoch 85/100
95/95 [==============================] - 3s - loss: 2.1498e-05 - acc: 1.0000 - val_loss: 1.3603 - val_acc: 0.7586
Epoch 86/100
95/95 [==============================] - 3s - loss: 2.0678e-05 - acc: 1.0000 - val_loss: 1.3614 - val_acc: 0.7586
Epoch 87/100
95/95 [==============================] - 3s - loss: 1.9395e-05 - acc: 1.0000 - val_loss: 1.3744 - val_acc: 0.7931
Epoch 88/100
95/95 [==============================] - 3s - loss: 1.8681e-05 - acc: 1.0000 - val_loss: 1.3903 - val_acc: 0.7931
Epoch 89/100
95/95 [==============================] - 3s - loss: 1.8408e-05 - acc: 1.0000 - val_loss: 1.3831 - val_acc: 0.7931
Epoch 90/100
95/95 [==============================] - 3s - loss: 1.7104e-05 - acc: 1.0000 - val_loss: 1.3804 - val_acc: 0.7586
Epoch 91/100
95/95 [==============================] - 3s - loss: 1.6838e-05 - acc: 1.0000 - val_loss: 1.3835 - val_acc: 0.7931
Epoch 92/100
95/95 [==============================] - 3s - loss: 1.5819e-05 - acc: 1.0000 - val_loss: 1.3876 - val_acc: 0.7586
Epoch 93/100
95/95 [==============================] - 3s - loss: 1.5668e-05 - acc: 1.0000 - val_loss: 1.3930 - val_acc: 0.7931
Epoch 94/100
95/95 [==============================] - 3s - loss: 1.4771e-05 - acc: 1.0000 - val_loss: 1.3959 - val_acc: 0.7931
Epoch 95/100
95/95 [==============================] - 3s - loss: 1.4084e-05 - acc: 1.0000 - val_loss: 1.3964 - val_acc: 0.7586
Epoch 96/100
95/95 [==============================] - 3s - loss: 1.3611e-05 - acc: 1.0000 - val_loss: 1.4038 - val_acc: 0.7931
Epoch 97/100
95/95 [==============================] - 3s - loss: 1.3335e-05 - acc: 1.0000 - val_loss: 1.4050 - val_acc: 0.7931
Epoch 98/100
95/95 [==============================] - 3s - loss: 1.2935e-05 - acc: 1.0000 - val_loss: 1.4124 - val_acc: 0.7931
Epoch 99/100
95/95 [==============================] - 3s - loss: 1.2700e-05 - acc: 1.0000 - val_loss: 1.4093 - val_acc: 0.7931
Epoch 100/100
95/95 [==============================] - 3s - loss: 1.2193e-05 - acc: 1.0000 - val_loss: 1.4160 - val_acc: 0.7931
In [94]:
plot_training_curves(history.history);
score = model2_nodropout.evaluate(X_test_norm, Y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
('Test loss:', 1.6642777919769287)
('Test accuracy:', 0.79310345649719238)

Indeed, without dropout the network quickly reaches a perfect training score while the validation loss climbs steadily -- classic overfitting, and test accuracy drops from ~93% to ~79%. It's really amazing that dropout lets us train ~180K parameters with only 95 training images in a generalizable way! (See this Geoff Hinton talk for some discussion of why this works...)
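To make the mechanism concrete, here's a minimal numpy sketch of "inverted" dropout on a toy activation vector. This is illustrative only -- Keras does the equivalent inside its Dropout layer -- but it shows why the expected activation is unchanged at test time while individual units are randomly silenced during training.

```python
import numpy as np

rng = np.random.RandomState(0)

def dropout(x, rate, train=True):
    """Zero each unit with probability `rate` during training,
    scaling the survivors by 1/(1-rate) so the expected
    activation matches the (un-dropped) test-time behavior."""
    if not train:
        return x
    keep = rng.binomial(1, 1.0 - rate, size=x.shape)
    return x * keep / (1.0 - rate)

x = np.ones(10000)
y = dropout(x, rate=0.5)
print(y.mean())        # close to 1.0: expectation preserved
print((y == 0).mean()) # close to 0.5: half the units silenced
```

Because each forward pass sees a different random sub-network, no single unit can rely on any other being present, which is one intuition for why it regularizes so effectively with tiny datasets like ours.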

Novel data from the internet

Let's look at a few images from the internet. I found these manually. To be scientific, we'd want to get more data. For now, just getting an initial sense.

In [43]:
import requests  # not in the setup imports above

def image_from_url(url):
    response = requests.get(url)
    img = Image.open(StringIO(response.content))
    return img
In [86]:
hummer_urls = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/5/58/2002-09-11-Marbella-22.jpg/220px-2002-09-11-Marbella-22.jpg",
    "https://s-media-cache-ak0.pinimg.com/originals/83/01/a0/8301a0d1d45e4a53e1a790fe79f6cbef.jpg",
    "http://vignette1.wikia.nocookie.net/asphalt/images/2/26/Hummer_h1.jpg/revision/latest?cb=20150120143504",
   ]

acura_urls = [
    "https://media.ed.edmunds-media.com/acura/rl/2012/oem/2012_acura_rl_sedan_base_fq_oem_6_500.jpg",
    "https://media.ed.edmunds-media.com/acura/rl/2012/oem/2012_acura_rl_sedan_base_fq_oem_3_500.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/f/fc/2005_Acura_RL_--_NHTSA.jpg/1200px-2005_Acura_RL_--_NHTSA.jpg"
]

hummers = map(image_from_url, hummer_urls)
acuras = map(image_from_url, acura_urls)
In [87]:
to_test = [img.resize((227,227))
           for img in hummers + acuras]
    
to_test = [np.array(img).astype('float32')/255.0
                    for img in to_test]

to_test = np.array(to_test)
to_test.shape
Out[87]:
(6, 227, 227, 3)
In [88]:
predicts = model2.predict(to_test)
ys = np.array([[1,0],[1,0],[1,0],
              [0,1],[0,1],[0,1]])
In [89]:
fig = plot_data(to_test, ys, predicts)
fig.suptitle("Internet test")
Out[89]:
<matplotlib.text.Text at 0x1506a6990>

It seems to work ok. Just for fun, let's make sure we don't already have these images -- the original dataset was probably collected by googling too.

In [96]:
def cmp_images(a,b):
    """
    Compare two images. Both must be numpy arrays, of same size, 
    float32 normalized to [0,1].
    Approximate...
    """
    if np.abs((a-b)).mean() < 0.1:  # hack, but should be good enough
        return True
    return False

# mini test
cmp_images(X_train_norm[0], X_train_norm[0]), cmp_images(
    X_train_norm[0], X_train_norm[1])
Out[96]:
(True, False)
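As an aside, the pairwise `cmp_images` check below is O(n²). A hedged alternative for catching *exact* duplicates in O(n) is to hash each (resized, normalized) array's bytes; `image_key` here is a hypothetical helper, and near-duplicates (re-encoded JPEGs, slightly different crops) would still need the approximate comparison.

```python
import hashlib
import numpy as np

def image_key(a, decimals=2):
    # round first so tiny float noise doesn't change the hash
    return hashlib.md5(np.round(a, decimals).tobytes()).hexdigest()

a = np.zeros((4, 4, 3), dtype='float32')
b = a.copy()
c = a + 0.5
print(image_key(a) == image_key(b))  # True -- exact duplicate
print(image_key(a) == image_key(c))  # False
```

With keys in hand, duplicates across train/valid/test could be found with a single dict lookup per image instead of a nested loop.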
In [91]:
stop = False
for a in it.chain(X_train_norm, X_valid_norm, X_test_norm):
    for b in to_test:
        if cmp_images(a,b):
            fig, plots = plt.subplots(1, 2)
            plots[0].imshow(a)
            plots[1].imshow(b)

A note on puppies

Note that our classifier divides the world into hummer and acura, and has no conception of anything else. Let's see what it does with puppies...

In [410]:
puppy_urls = [
    "http://cdn.earthporm.com/wp-content/uploads/2015/10/XX-Proud-Mommies5__605.jpg",
    "https://ipetcompanion.com/feedapuppy/styles/media/puppy.jpg",
    "https://i.ytimg.com/vi/PnY7WqoN4F8/hqdefault.jpg"
]

puppies = map(image_from_url, puppy_urls)
In [411]:
to_test = [img.resize((227,227))
           for img in puppies]
    
to_test = [np.array(img).astype('float32')/255.0
                    for img in to_test]

to_test = np.array(to_test)
to_test.shape
Out[411]:
(3, 227, 227, 3)
In [412]:
predicts = model2.predict(to_test)
# We'll pretend that puppies are hummers, just to make the code happy
ys = np.array([[1,0],[1,0],[1,0]])
fig = plot_data(to_test, ys, predicts)
fig.suptitle("Puppies test")
Out[412]:
<matplotlib.text.Text at 0x1acdfec10>

And voila -- we have two Hummer puppies and an Acura puppy. The lesson: either ensure that you're only feeding your classifier data from one of the expected classes, or prepare to handle inputs that don't belong. We can include a none-of-the-above class, or predict the probability of each class independently, without using a softmax, so they can all be low at once. Either way, we'd need to feed our network enough negative data.

As a final step before moving on to a more complex problem, let's add TensorBoard so we can look at what's happening in our network.

  • Result: TensorBoard doesn't seem to work very well with Keras yet. Layer names are messed up, so graph looks wonky. Histograms seem to work ok though.

Problem 2

Ok, now that we have a basic pipeline working, let's try a more complex problem:

  • classify images as sedan, SUV, pickup, or none of the above.
  • we'll use our Cars dataset, perhaps augmented with some random images

Create our new class hierarchy

As a first step, let's finish up our car model classification into sedan, SUV, etc.

In [133]:
classes
Out[133]:
[u'AM General Hummer SUV 2000',
 u'Acura RL Sedan 2012',
 u'Acura TL Sedan 2012',
 u'Acura TL Type-S 2008',
 u'Acura TSX Sedan 2012',
 u'Acura Integra Type R 2001',
 u'Acura ZDX Hatchback 2012',
 u'Aston Martin V8 Vantage Convertible 2012',
 u'Aston Martin V8 Vantage Coupe 2012',
 u'Aston Martin Virage Convertible 2012',
 u'Aston Martin Virage Coupe 2012',
 u'Audi RS 4 Convertible 2008',
 u'Audi A5 Coupe 2012',
 u'Audi TTS Coupe 2012',
 u'Audi R8 Coupe 2012',
 u'Audi V8 Sedan 1994',
 u'Audi 100 Sedan 1994',
 u'Audi 100 Wagon 1994',
 u'Audi TT Hatchback 2011',
 u'Audi S6 Sedan 2011',
 u'Audi S5 Convertible 2012',
 u'Audi S5 Coupe 2012',
 u'Audi S4 Sedan 2012',
 u'Audi S4 Sedan 2007',
 u'Audi TT RS Coupe 2012',
 u'BMW ActiveHybrid 5 Sedan 2012',
 u'BMW 1 Series Convertible 2012',
 u'BMW 1 Series Coupe 2012',
 u'BMW 3 Series Sedan 2012',
 u'BMW 3 Series Wagon 2012',
 u'BMW 6 Series Convertible 2007',
 u'BMW X5 SUV 2007',
 u'BMW X6 SUV 2012',
 u'BMW M3 Coupe 2012',
 u'BMW M5 Sedan 2010',
 u'BMW M6 Convertible 2010',
 u'BMW X3 SUV 2012',
 u'BMW Z4 Convertible 2012',
 u'Bentley Continental Supersports Conv. Convertible 2012',
 u'Bentley Arnage Sedan 2009',
 u'Bentley Mulsanne Sedan 2011',
 u'Bentley Continental GT Coupe 2012',
 u'Bentley Continental GT Coupe 2007',
 u'Bentley Continental Flying Spur Sedan 2007',
 u'Bugatti Veyron 16.4 Convertible 2009',
 u'Bugatti Veyron 16.4 Coupe 2009',
 u'Buick Regal GS 2012',
 u'Buick Rainier SUV 2007',
 u'Buick Verano Sedan 2012',
 u'Buick Enclave SUV 2012',
 u'Cadillac CTS-V Sedan 2012',
 u'Cadillac SRX SUV 2012',
 u'Cadillac Escalade EXT Crew Cab 2007',
 u'Chevrolet Silverado 1500 Hybrid Crew Cab 2012',
 u'Chevrolet Corvette Convertible 2012',
 u'Chevrolet Corvette ZR1 2012',
 u'Chevrolet Corvette Ron Fellows Edition Z06 2007',
 u'Chevrolet Traverse SUV 2012',
 u'Chevrolet Camaro Convertible 2012',
 u'Chevrolet HHR SS 2010',
 u'Chevrolet Impala Sedan 2007',
 u'Chevrolet Tahoe Hybrid SUV 2012',
 u'Chevrolet Sonic Sedan 2012',
 u'Chevrolet Express Cargo Van 2007',
 u'Chevrolet Avalanche Crew Cab 2012',
 u'Chevrolet Cobalt SS 2010',
 u'Chevrolet Malibu Hybrid Sedan 2010',
 u'Chevrolet TrailBlazer SS 2009',
 u'Chevrolet Silverado 2500HD Regular Cab 2012',
 u'Chevrolet Silverado 1500 Classic Extended Cab 2007',
 u'Chevrolet Express Van 2007',
 u'Chevrolet Monte Carlo Coupe 2007',
 u'Chevrolet Malibu Sedan 2007',
 u'Chevrolet Silverado 1500 Extended Cab 2012',
 u'Chevrolet Silverado 1500 Regular Cab 2012',
 u'Chrysler Aspen SUV 2009',
 u'Chrysler Sebring Convertible 2010',
 u'Chrysler Town and Country Minivan 2012',
 u'Chrysler 300 SRT-8 2010',
 u'Chrysler Crossfire Convertible 2008',
 u'Chrysler PT Cruiser Convertible 2008',
 u'Daewoo Nubira Wagon 2002',
 u'Dodge Caliber Wagon 2012',
 u'Dodge Caliber Wagon 2007',
 u'Dodge Caravan Minivan 1997',
 u'Dodge Ram Pickup 3500 Crew Cab 2010',
 u'Dodge Ram Pickup 3500 Quad Cab 2009',
 u'Dodge Sprinter Cargo Van 2009',
 u'Dodge Journey SUV 2012',
 u'Dodge Dakota Crew Cab 2010',
 u'Dodge Dakota Club Cab 2007',
 u'Dodge Magnum Wagon 2008',
 u'Dodge Challenger SRT8 2011',
 u'Dodge Durango SUV 2012',
 u'Dodge Durango SUV 2007',
 u'Dodge Charger Sedan 2012',
 u'Dodge Charger SRT-8 2009',
 u'Eagle Talon Hatchback 1998',
 u'FIAT 500 Abarth 2012',
 u'FIAT 500 Convertible 2012',
 u'Ferrari FF Coupe 2012',
 u'Ferrari California Convertible 2012',
 u'Ferrari 458 Italia Convertible 2012',
 u'Ferrari 458 Italia Coupe 2012',
 u'Fisker Karma Sedan 2012',
 u'Ford F-450 Super Duty Crew Cab 2012',
 u'Ford Mustang Convertible 2007',
 u'Ford Freestar Minivan 2007',
 u'Ford Expedition EL SUV 2009',
 u'Ford Edge SUV 2012',
 u'Ford Ranger SuperCab 2011',
 u'Ford GT Coupe 2006',
 u'Ford F-150 Regular Cab 2012',
 u'Ford F-150 Regular Cab 2007',
 u'Ford Focus Sedan 2007',
 u'Ford E-Series Wagon Van 2012',
 u'Ford Fiesta Sedan 2012',
 u'GMC Terrain SUV 2012',
 u'GMC Savana Van 2012',
 u'GMC Yukon Hybrid SUV 2012',
 u'GMC Acadia SUV 2012',
 u'GMC Canyon Extended Cab 2012',
 u'Geo Metro Convertible 1993',
 u'HUMMER H3T Crew Cab 2010',
 u'HUMMER H2 SUT Crew Cab 2009',
 u'Honda Odyssey Minivan 2012',
 u'Honda Odyssey Minivan 2007',
 u'Honda Accord Coupe 2012',
 u'Honda Accord Sedan 2012',
 u'Hyundai Veloster Hatchback 2012',
 u'Hyundai Santa Fe SUV 2012',
 u'Hyundai Tucson SUV 2012',
 u'Hyundai Veracruz SUV 2012',
 u'Hyundai Sonata Hybrid Sedan 2012',
 u'Hyundai Elantra Sedan 2007',
 u'Hyundai Accent Sedan 2012',
 u'Hyundai Genesis Sedan 2012',
 u'Hyundai Sonata Sedan 2012',
 u'Hyundai Elantra Touring Hatchback 2012',
 u'Hyundai Azera Sedan 2012',
 u'Infiniti G Coupe IPL 2012',
 u'Infiniti QX56 SUV 2011',
 u'Isuzu Ascender SUV 2008',
 u'Jaguar XK XKR 2012',
 u'Jeep Patriot SUV 2012',
 u'Jeep Wrangler SUV 2012',
 u'Jeep Liberty SUV 2012',
 u'Jeep Grand Cherokee SUV 2012',
 u'Jeep Compass SUV 2012',
 u'Lamborghini Reventon Coupe 2008',
 u'Lamborghini Aventador Coupe 2012',
 u'Lamborghini Gallardo LP 570-4 Superleggera 2012',
 u'Lamborghini Diablo Coupe 2001',
 u'Land Rover Range Rover SUV 2012',
 u'Land Rover LR2 SUV 2012',
 u'Lincoln Town Car Sedan 2011',
 u'MINI Cooper Roadster Convertible 2012',
 u'Maybach Landaulet Convertible 2012',
 u'Mazda Tribute SUV 2011',
 u'McLaren MP4-12C Coupe 2012',
 u'Mercedes-Benz 300-Class Convertible 1993',
 u'Mercedes-Benz C-Class Sedan 2012',
 u'Mercedes-Benz SL-Class Coupe 2009',
 u'Mercedes-Benz E-Class Sedan 2012',
 u'Mercedes-Benz S-Class Sedan 2012',
 u'Mercedes-Benz Sprinter Van 2012',
 u'Mitsubishi Lancer Sedan 2012',
 u'Nissan Leaf Hatchback 2012',
 u'Nissan NV Passenger Van 2012',
 u'Nissan Juke Hatchback 2012',
 u'Nissan 240SX Coupe 1998',
 u'Plymouth Neon Coupe 1999',
 u'Porsche Panamera Sedan 2012',
 u'Ram C/V Cargo Van Minivan 2012',
 u'Rolls-Royce Phantom Drophead Coupe Convertible 2012',
 u'Rolls-Royce Ghost Sedan 2012',
 u'Rolls-Royce Phantom Sedan 2012',
 u'Scion xD Hatchback 2012',
 u'Spyker C8 Convertible 2009',
 u'Spyker C8 Coupe 2009',
 u'Suzuki Aerio Sedan 2007',
 u'Suzuki Kizashi Sedan 2012',
 u'Suzuki SX4 Hatchback 2012',
 u'Suzuki SX4 Sedan 2012',
 u'Tesla Model S Sedan 2012',
 u'Toyota Sequoia SUV 2012',
 u'Toyota Camry Sedan 2012',
 u'Toyota Corolla Sedan 2012',
 u'Toyota 4Runner SUV 2012',
 u'Volkswagen Golf Hatchback 2012',
 u'Volkswagen Golf Hatchback 1991',
 u'Volkswagen Beetle Hatchback 2012',
 u'Volvo C30 Hatchback 2012',
 u'Volvo 240 Sedan 1993',
 u'Volvo XC90 SUV 2007',
 u'smart fortwo Convertible 2012']

Looks good. Now let's split each class into a (brand, model, type, year) tuple. This will require a bit of manual munging...

In [44]:
# pull out brands -- will need a bit of massaging for two-word names
# like aston martin and land rover
pd.unique([c.split()[0] for c in classes])
Out[44]:
array([u'AM', u'Acura', u'Aston', u'Audi', u'BMW', u'Bentley', u'Bugatti',
       u'Buick', u'Cadillac', u'Chevrolet', u'Chrysler', u'Daewoo',
       u'Dodge', u'Eagle', u'FIAT', u'Ferrari', u'Fisker', u'Ford', u'GMC',
       u'Geo', u'HUMMER', u'Honda', u'Hyundai', u'Infiniti', u'Isuzu',
       u'Jaguar', u'Jeep', u'Lamborghini', u'Land', u'Lincoln', u'MINI',
       u'Maybach', u'Mazda', u'McLaren', u'Mercedes-Benz', u'Mitsubishi',
       u'Nissan', u'Plymouth', u'Porsche', u'Ram', u'Rolls-Royce',
       u'Scion', u'Spyker', u'Suzuki', u'Tesla', u'Toyota', u'Volkswagen',
       u'Volvo', u'smart'], dtype=object)
In [45]:
# pull out car types
pd.unique([c.split()[-2] for c in classes])
Out[45]:
array([u'SUV', u'Sedan', u'Type-S', u'R', u'Hatchback', u'Convertible',
       u'Coupe', u'Wagon', u'GS', u'Cab', u'ZR1', u'Z06', u'SS', u'Van',
       u'Minivan', u'SRT-8', u'SRT8', u'Abarth', u'SuperCab', u'IPL',
       u'XKR', u'Superleggera'], dtype=object)
In [46]:
def parse_classes(classes):
    """
    Return (id, brand, model, type, year) tuples.
    Type will be one of:
       * Sedan
       * Convertible
       * Coupe
       * SUV
       * Pickup
       * Van
       * Wagon
       
    We may combine these further later...
    """
    brands = [u'AM', u'Acura', u'Aston Martin', u'Audi', u'BMW', u'Bentley', u'Bugatti',
       u'Buick', u'Cadillac', u'Chevrolet', u'Chrysler', u'Daewoo',
       u'Dodge', u'Eagle', u'FIAT', u'Ferrari', u'Fisker', u'Ford', u'GMC',
       u'Geo', u'HUMMER', u'Honda', u'Hyundai', u'Infiniti', u'Isuzu',
       u'Jaguar', u'Jeep', u'Lamborghini', u'Land Rover', u'Lincoln', u'MINI',
       u'Maybach', u'Mazda', u'McLaren', u'Mercedes-Benz', u'Mitsubishi',
       u'Nissan', u'Plymouth', u'Porsche', u'Ram', u'Rolls-Royce',
       u'Scion', u'Spyker', u'Suzuki', u'Tesla', u'Toyota', u'Volkswagen',
       u'Volvo', u'smart']
    
    car_types = ['SUV', 'Sedan', 'Convertible', 'Coupe', 
                 'Wagon', 'Minivan', 'Van']
    # Some types we'll just remap
    # (Used Google image search to figure out some of these).
    car_type_map = {'Hatchback': 'Wagon',
                    'SuperCab' : 'Pickup',
                    'Type-S' : 'Sedan',
                    'Type R' : 'Coupe',
                    'Cab' : 'Pickup',
                    'GS': 'Sedan',
                    'ZR1' : 'Coupe',
                    'Z06' : 'Coupe',
                    'HHR SS' : 'SUV',
                    'Cobalt SS': 'Coupe',
                    'TrailBlazer SS': 'SUV',
                    '300 SRT-8': 'Sedan',
                    'Challenger SRT8' : 'Coupe',
                    'Charger SRT-8': 'Sedan',
                    # different style coupe -- I expect difficulties with
                    # the fiat 500...
                    '500 Abarth': 'Coupe',
                    'Coupe IPL': 'Coupe',
                    'XKR' : 'Coupe',
                    'Superleggera' : 'Coupe'
                   }
                    
    ret = []                
    for i, cls in enumerate(classes):
        words = cls.split()
        year = words[-1]
        if words[-2] in car_types:
            car_type = words[-2]
        elif words[-2] in car_type_map:
            car_type = car_type_map[words[-2]]
        else:
            # look for last two words match
            key = words[-3] + " " + words[-2]
            if key in car_type_map:
                car_type = car_type_map[key]
            else:
                print("Unknown car type: ", cls)
        
        # make sure everything is unicode to avoid any
        # issues later
        car_type = unicode(car_type)
        
        # just search
        for b in brands:
            if b in cls:
                brand = b
                brand_len = len(brand.split())

        # this will be approximate, but I don't really care
        # about the model, at least for now
        model = " ".join(words[brand_len:-2])
        # Careful! Class ids start at 1
        ret.append((i+1, brand, model, car_type, year))

    return ret

cls_tuples = parse_classes(classes)
cls_tuples
Out[46]:
[(1, u'AM', u'General Hummer', u'SUV', u'2000'),
 (2, u'Acura', u'RL', u'Sedan', u'2012'),
 (3, u'Acura', u'TL', u'Sedan', u'2012'),
 (4, u'Acura', u'TL', u'Sedan', u'2008'),
 (5, u'Acura', u'TSX', u'Sedan', u'2012'),
 (6, u'Acura', u'Integra Type', u'Coupe', u'2001'),
 (7, u'Acura', u'ZDX', u'Wagon', u'2012'),
 (8, u'Aston Martin', u'V8 Vantage', u'Convertible', u'2012'),
 (9, u'Aston Martin', u'V8 Vantage', u'Coupe', u'2012'),
 (10, u'Aston Martin', u'Virage', u'Convertible', u'2012'),
 (11, u'Aston Martin', u'Virage', u'Coupe', u'2012'),
 (12, u'Audi', u'RS 4', u'Convertible', u'2008'),
 (13, u'Audi', u'A5', u'Coupe', u'2012'),
 (14, u'Audi', u'TTS', u'Coupe', u'2012'),
 (15, u'Audi', u'R8', u'Coupe', u'2012'),
 (16, u'Audi', u'V8', u'Sedan', u'1994'),
 (17, u'Audi', u'100', u'Sedan', u'1994'),
 (18, u'Audi', u'100', u'Wagon', u'1994'),
 (19, u'Audi', u'TT', u'Wagon', u'2011'),
 (20, u'Audi', u'S6', u'Sedan', u'2011'),
 (21, u'Audi', u'S5', u'Convertible', u'2012'),
 (22, u'Audi', u'S5', u'Coupe', u'2012'),
 (23, u'Audi', u'S4', u'Sedan', u'2012'),
 (24, u'Audi', u'S4', u'Sedan', u'2007'),
 (25, u'Audi', u'TT RS', u'Coupe', u'2012'),
 (26, u'BMW', u'ActiveHybrid 5', u'Sedan', u'2012'),
 (27, u'BMW', u'1 Series', u'Convertible', u'2012'),
 (28, u'BMW', u'1 Series', u'Coupe', u'2012'),
 (29, u'BMW', u'3 Series', u'Sedan', u'2012'),
 (30, u'BMW', u'3 Series', u'Wagon', u'2012'),
 (31, u'BMW', u'6 Series', u'Convertible', u'2007'),
 (32, u'BMW', u'X5', u'SUV', u'2007'),
 (33, u'BMW', u'X6', u'SUV', u'2012'),
 (34, u'BMW', u'M3', u'Coupe', u'2012'),
 (35, u'BMW', u'M5', u'Sedan', u'2010'),
 (36, u'BMW', u'M6', u'Convertible', u'2010'),
 (37, u'BMW', u'X3', u'SUV', u'2012'),
 (38, u'BMW', u'Z4', u'Convertible', u'2012'),
 (39, u'Bentley', u'Continental Supersports Conv.', u'Convertible', u'2012'),
 (40, u'Bentley', u'Arnage', u'Sedan', u'2009'),
 (41, u'Bentley', u'Mulsanne', u'Sedan', u'2011'),
 (42, u'Bentley', u'Continental GT', u'Coupe', u'2012'),
 (43, u'Bentley', u'Continental GT', u'Coupe', u'2007'),
 (44, u'Bentley', u'Continental Flying Spur', u'Sedan', u'2007'),
 (45, u'Bugatti', u'Veyron 16.4', u'Convertible', u'2009'),
 (46, u'Bugatti', u'Veyron 16.4', u'Coupe', u'2009'),
 (47, u'Buick', u'Regal', u'Sedan', u'2012'),
 (48, u'Buick', u'Rainier', u'SUV', u'2007'),
 (49, u'Buick', u'Verano', u'Sedan', u'2012'),
 (50, u'Buick', u'Enclave', u'SUV', u'2012'),
 (51, u'Cadillac', u'CTS-V', u'Sedan', u'2012'),
 (52, u'Cadillac', u'SRX', u'SUV', u'2012'),
 (53, u'Cadillac', u'Escalade EXT Crew', u'Pickup', u'2007'),
 (54, u'Chevrolet', u'Silverado 1500 Hybrid Crew', u'Pickup', u'2012'),
 (55, u'Chevrolet', u'Corvette', u'Convertible', u'2012'),
 (56, u'Chevrolet', u'Corvette', u'Coupe', u'2012'),
 (57, u'Chevrolet', u'Corvette Ron Fellows Edition', u'Coupe', u'2007'),
 (58, u'Chevrolet', u'Traverse', u'SUV', u'2012'),
 (59, u'Chevrolet', u'Camaro', u'Convertible', u'2012'),
 (60, u'Chevrolet', u'HHR', u'SUV', u'2010'),
 (61, u'Chevrolet', u'Impala', u'Sedan', u'2007'),
 (62, u'Chevrolet', u'Tahoe Hybrid', u'SUV', u'2012'),
 (63, u'Chevrolet', u'Sonic', u'Sedan', u'2012'),
 (64, u'Chevrolet', u'Express Cargo', u'Van', u'2007'),
 (65, u'Chevrolet', u'Avalanche Crew', u'Pickup', u'2012'),
 (66, u'Chevrolet', u'Cobalt', u'Coupe', u'2010'),
 (67, u'Chevrolet', u'Malibu Hybrid', u'Sedan', u'2010'),
 (68, u'Chevrolet', u'TrailBlazer', u'SUV', u'2009'),
 (69, u'Chevrolet', u'Silverado 2500HD Regular', u'Pickup', u'2012'),
 (70, u'Chevrolet', u'Silverado 1500 Classic Extended', u'Pickup', u'2007'),
 (71, u'Chevrolet', u'Express', u'Van', u'2007'),
 (72, u'Chevrolet', u'Monte Carlo', u'Coupe', u'2007'),
 (73, u'Chevrolet', u'Malibu', u'Sedan', u'2007'),
 (74, u'Chevrolet', u'Silverado 1500 Extended', u'Pickup', u'2012'),
 (75, u'Chevrolet', u'Silverado 1500 Regular', u'Pickup', u'2012'),
 (76, u'Chrysler', u'Aspen', u'SUV', u'2009'),
 (77, u'Chrysler', u'Sebring', u'Convertible', u'2010'),
 (78, u'Chrysler', u'Town and Country', u'Minivan', u'2012'),
 (79, u'Chrysler', u'300', u'Sedan', u'2010'),
 (80, u'Chrysler', u'Crossfire', u'Convertible', u'2008'),
 (81, u'Chrysler', u'PT Cruiser', u'Convertible', u'2008'),
 (82, u'Daewoo', u'Nubira', u'Wagon', u'2002'),
 (83, u'Dodge', u'Caliber', u'Wagon', u'2012'),
 (84, u'Dodge', u'Caliber', u'Wagon', u'2007'),
 (85, u'Dodge', u'Caravan', u'Minivan', u'1997'),
 (86, u'Ram', u'Ram Pickup 3500 Crew', u'Pickup', u'2010'),
 (87, u'Ram', u'Ram Pickup 3500 Quad', u'Pickup', u'2009'),
 (88, u'Dodge', u'Sprinter Cargo', u'Van', u'2009'),
 (89, u'Dodge', u'Journey', u'SUV', u'2012'),
 (90, u'Dodge', u'Dakota Crew', u'Pickup', u'2010'),
 (91, u'Dodge', u'Dakota Club', u'Pickup', u'2007'),
 (92, u'Dodge', u'Magnum', u'Wagon', u'2008'),
 (93, u'Dodge', u'Challenger', u'Coupe', u'2011'),
 (94, u'Dodge', u'Durango', u'SUV', u'2012'),
 (95, u'Dodge', u'Durango', u'SUV', u'2007'),
 (96, u'Dodge', u'Charger', u'Sedan', u'2012'),
 (97, u'Dodge', u'Charger', u'Sedan', u'2009'),
 (98, u'Eagle', u'Talon', u'Wagon', u'1998'),
 (99, u'FIAT', u'500', u'Coupe', u'2012'),
 (100, u'FIAT', u'500', u'Convertible', u'2012'),
 (101, u'Ferrari', u'FF', u'Coupe', u'2012'),
 (102, u'Ferrari', u'California', u'Convertible', u'2012'),
 (103, u'Ferrari', u'458 Italia', u'Convertible', u'2012'),
 (104, u'Ferrari', u'458 Italia', u'Coupe', u'2012'),
 (105, u'Fisker', u'Karma', u'Sedan', u'2012'),
 (106, u'Ford', u'F-450 Super Duty Crew', u'Pickup', u'2012'),
 (107, u'Ford', u'Mustang', u'Convertible', u'2007'),
 (108, u'Ford', u'Freestar', u'Minivan', u'2007'),
 (109, u'Ford', u'Expedition EL', u'SUV', u'2009'),
 (110, u'Ford', u'Edge', u'SUV', u'2012'),
 (111, u'Ford', u'Ranger', u'Pickup', u'2011'),
 (112, u'Ford', u'GT', u'Coupe', u'2006'),
 (113, u'Ford', u'F-150 Regular', u'Pickup', u'2012'),
 (114, u'Ford', u'F-150 Regular', u'Pickup', u'2007'),
 (115, u'Ford', u'Focus', u'Sedan', u'2007'),
 (116, u'Ford', u'E-Series Wagon', u'Van', u'2012'),
 (117, u'Ford', u'Fiesta', u'Sedan', u'2012'),
 (118, u'GMC', u'Terrain', u'SUV', u'2012'),
 (119, u'GMC', u'Savana', u'Van', u'2012'),
 (120, u'GMC', u'Yukon Hybrid', u'SUV', u'2012'),
 (121, u'GMC', u'Acadia', u'SUV', u'2012'),
 (122, u'GMC', u'Canyon Extended', u'Pickup', u'2012'),
 (123, u'Geo', u'Metro', u'Convertible', u'1993'),
 (124, u'HUMMER', u'H3T Crew', u'Pickup', u'2010'),
 (125, u'HUMMER', u'H2 SUT Crew', u'Pickup', u'2009'),
 (126, u'Honda', u'Odyssey', u'Minivan', u'2012'),
 (127, u'Honda', u'Odyssey', u'Minivan', u'2007'),
 (128, u'Honda', u'Accord', u'Coupe', u'2012'),
 (129, u'Honda', u'Accord', u'Sedan', u'2012'),
 (130, u'Hyundai', u'Veloster', u'Wagon', u'2012'),
 (131, u'Hyundai', u'Santa Fe', u'SUV', u'2012'),
 (132, u'Hyundai', u'Tucson', u'SUV', u'2012'),
 (133, u'Hyundai', u'Veracruz', u'SUV', u'2012'),
 (134, u'Hyundai', u'Sonata Hybrid', u'Sedan', u'2012'),
 (135, u'Hyundai', u'Elantra', u'Sedan', u'2007'),
 (136, u'Hyundai', u'Accent', u'Sedan', u'2012'),
 (137, u'Hyundai', u'Genesis', u'Sedan', u'2012'),
 (138, u'Hyundai', u'Sonata', u'Sedan', u'2012'),
 (139, u'Hyundai', u'Elantra Touring', u'Wagon', u'2012'),
 (140, u'Hyundai', u'Azera', u'Sedan', u'2012'),
 (141, u'Infiniti', u'G Coupe', u'Coupe', u'2012'),
 (142, u'Infiniti', u'QX56', u'SUV', u'2011'),
 (143, u'Isuzu', u'Ascender', u'SUV', u'2008'),
 (144, u'Jaguar', u'XK', u'Coupe', u'2012'),
 (145, u'Jeep', u'Patriot', u'SUV', u'2012'),
 (146, u'Jeep', u'Wrangler', u'SUV', u'2012'),
 (147, u'Jeep', u'Liberty', u'SUV', u'2012'),
 (148, u'Jeep', u'Grand Cherokee', u'SUV', u'2012'),
 (149, u'Jeep', u'Compass', u'SUV', u'2012'),
 (150, u'Lamborghini', u'Reventon', u'Coupe', u'2008'),
 (151, u'Lamborghini', u'Aventador', u'Coupe', u'2012'),
 (152, u'Lamborghini', u'Gallardo LP 570-4', u'Coupe', u'2012'),
 (153, u'Lamborghini', u'Diablo', u'Coupe', u'2001'),
 (154, u'Land Rover', u'Range Rover', u'SUV', u'2012'),
 (155, u'Land Rover', u'LR2', u'SUV', u'2012'),
 (156, u'Lincoln', u'Town Car', u'Sedan', u'2011'),
 (157, u'MINI', u'Cooper Roadster', u'Convertible', u'2012'),
 (158, u'Maybach', u'Landaulet', u'Convertible', u'2012'),
 (159, u'Mazda', u'Tribute', u'SUV', u'2011'),
 (160, u'McLaren', u'MP4-12C', u'Coupe', u'2012'),
 (161, u'Mercedes-Benz', u'300-Class', u'Convertible', u'1993'),
 (162, u'Mercedes-Benz', u'C-Class', u'Sedan', u'2012'),
 (163, u'Mercedes-Benz', u'SL-Class', u'Coupe', u'2009'),
 (164, u'Mercedes-Benz', u'E-Class', u'Sedan', u'2012'),
 (165, u'Mercedes-Benz', u'S-Class', u'Sedan', u'2012'),
 (166, u'Mercedes-Benz', u'Sprinter', u'Van', u'2012'),
 (167, u'Mitsubishi', u'Lancer', u'Sedan', u'2012'),
 (168, u'Nissan', u'Leaf', u'Wagon', u'2012'),
 (169, u'Nissan', u'NV Passenger', u'Van', u'2012'),
 (170, u'Nissan', u'Juke', u'Wagon', u'2012'),
 (171, u'Nissan', u'240SX', u'Coupe', u'1998'),
 (172, u'Plymouth', u'Neon', u'Coupe', u'1999'),
 (173, u'Porsche', u'Panamera', u'Sedan', u'2012'),
 (174, u'Ram', u'C/V Cargo Van', u'Minivan', u'2012'),
 (175, u'Rolls-Royce', u'Phantom Drophead Coupe', u'Convertible', u'2012'),
 (176, u'Rolls-Royce', u'Ghost', u'Sedan', u'2012'),
 (177, u'Rolls-Royce', u'Phantom', u'Sedan', u'2012'),
 (178, u'Scion', u'xD', u'Wagon', u'2012'),
 (179, u'Spyker', u'C8', u'Convertible', u'2009'),
 (180, u'Spyker', u'C8', u'Coupe', u'2009'),
 (181, u'Suzuki', u'Aerio', u'Sedan', u'2007'),
 (182, u'Suzuki', u'Kizashi', u'Sedan', u'2012'),
 (183, u'Suzuki', u'SX4', u'Wagon', u'2012'),
 (184, u'Suzuki', u'SX4', u'Sedan', u'2012'),
 (185, u'Tesla', u'Model S', u'Sedan', u'2012'),
 (186, u'Toyota', u'Sequoia', u'SUV', u'2012'),
 (187, u'Toyota', u'Camry', u'Sedan', u'2012'),
 (188, u'Toyota', u'Corolla', u'Sedan', u'2012'),
 (189, u'Toyota', u'4Runner', u'SUV', u'2012'),
 (190, u'Volkswagen', u'Golf', u'Wagon', u'2012'),
 (191, u'Volkswagen', u'Golf', u'Wagon', u'1991'),
 (192, u'Volkswagen', u'Beetle', u'Wagon', u'2012'),
 (193, u'Volvo', u'C30', u'Wagon', u'2012'),
 (194, u'Volvo', u'240', u'Sedan', u'1993'),
 (195, u'Volvo', u'XC90', u'SUV', u'2007'),
 (196, u'smart', u'fortwo', u'Convertible', u'2012')]
In [47]:
by_car_type = {} # type -> [tuples]
key_fn = operator.itemgetter(3)
for ct, group in it.groupby(sorted(cls_tuples, key=key_fn), key_fn):
    by_car_type[ct] = list(group)

for k,v in sorted(by_car_type.items(), key=lambda x: len(x[1])):
    print(k, len(v))
(u'Minivan', 6)
(u'Van', 7)
(u'Pickup', 18)
(u'Wagon', 19)
(u'Convertible', 26)
(u'Coupe', 34)
(u'SUV', 36)
(u'Sedan', 50)

Hmm. The number of car models per macro class is quite uneven, which may make classification harder.
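
To put a number on the imbalance: assuming each model contributes roughly the same number of images, a classifier that always predicts Sedan would already get about a quarter of the examples right, so that's the accuracy floor to beat. A quick sketch using the counts from the groupby output above:

```python
# model counts per macro class, copied from the groupby output above
counts = {'Minivan': 6, 'Van': 7, 'Pickup': 18, 'Wagon': 19,
          'Convertible': 26, 'Coupe': 34, 'SUV': 36, 'Sedan': 50}

total = float(sum(counts.values()))      # 196 classes in all
baseline = max(counts.values()) / total  # always predict 'Sedan'
print(round(baseline, 3))                # -> 0.255
```

Interestingly, that's close to the validation accuracy the model starts from in the first epoch below.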

Ok, let's just try it with these classes -- load the images per class, give them appropriate labels.

To make the model smaller, we could start with just 20 images per car model (the code below is currently set to use all of them).

In [48]:
# make a mapping from class name to numbers we can one-hot encode
macro_classes = sorted(pd.unique([c[3] for c in cls_tuples]))
macro_class_map = dict((v,k) for (k,v) in enumerate(macro_classes))
macro_class_map
Out[48]:
{u'Convertible': 0,
 u'Coupe': 1,
 u'Minivan': 2,
 u'Pickup': 3,
 u'SUV': 4,
 u'Sedan': 5,
 u'Van': 6,
 u'Wagon': 7}

Let's save some metadata so we can load it in other notebooks -- this one is getting rather unwieldy.

In [66]:
import cPickle as pickle

save = True
if save:
    with open('class_details.pkl', 'wb') as f:  # binary mode for pickle
        pickle.dump({'classes': classes,
                     'examples': examples,
                     'by_class': by_class,
                     'by_car_type': by_car_type,
                     'macro_classes': macro_classes,
                     'macro_class_map': macro_class_map,
                     'cls_tuples': cls_tuples}, f)
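
Reading the metadata back in the other notebooks is the mirror image of the dump. A minimal round-trip sketch (the filename and contents here are stand-ins, not the real metadata):

```python
import pickle  # `import cPickle as pickle` on Python 2

# hypothetical stand-in for the saved metadata dict
meta = {'macro_classes': ['Convertible', 'Coupe'],
        'macro_class_map': {'Convertible': 0, 'Coupe': 1}}

# write and read back, both in binary mode
with open('class_details_demo.pkl', 'wb') as f:
    pickle.dump(meta, f)
with open('class_details_demo.pkl', 'rb') as f:
    loaded = pickle.load(f)

assert loaded == meta
```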

Load and split the data

In [49]:
IMG_PER_CAR = None  # e.g. 20 for a quick run; None to use all images
valid_frac = 0.2
test_frac = 0.2

train = []
valid = []
test = []
for car_type, model_tuples in by_car_type.items():
    macro_class_id = macro_class_map[car_type]
    
    for model_tpl in model_tuples:
        cls = model_tpl[0]
        examples = load_examples(by_class, cls, limit=IMG_PER_CAR)
        # replace class labels with the id of the macro class
        examples = [(X, macro_class_id) for (X,y) in examples]
        # split each class separately, so all have same fractions of 
        # train/valid/test
        (cls_train, cls_valid, cls_test) = split_examples(
            examples,
            valid_frac, test_frac)
        # and add them to the overall train/valid/test sets
        train.extend(cls_train)
        valid.extend(cls_valid)
        test.extend(cls_test)

# ...and shuffle to make training work better.
np.random.shuffle(train)
np.random.shuffle(valid)
np.random.shuffle(test)
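
`split_examples` was defined earlier in the notebook; for reference, a per-class fractional split like the one used above might look like this (a sketch, not the original helper):

```python
def split_examples(examples, valid_frac, test_frac):
    """Split one class's examples into (train, valid, test).

    Sketch: carve off the validation and test fractions,
    and the training set gets whatever remains.
    """
    n = len(examples)
    n_valid = int(n * valid_frac)
    n_test = int(n * test_frac)
    valid = examples[:n_valid]
    test = examples[n_valid:n_valid + n_test]
    train = examples[n_valid + n_test:]
    return train, valid, test

train, valid, test = split_examples(list(range(10)), 0.2, 0.2)
print((len(train), len(valid), len(test)))  # -> (6, 2, 2)
```

Splitting each class separately like this keeps the class proportions identical across the three sets.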

Still copy-pasting from above. Could refactor...

In [50]:
# We have lists of (X,Y) tuples. Let's unzip into lists of Xs and Ys.
X_train, Y_train = zip(*train)
X_valid, Y_valid = zip(*valid)
X_test, Y_test = zip(*test)

# and turn into np arrays of the right dimension.
def convert_X(xs):
    '''
    Take list of (w,h,3) images.
    Turn into an np array, change type to float32.
    '''
    return np.array(xs).astype('float32')
    
X_train = convert_X(X_train)
X_valid = convert_X(X_valid)
X_test = convert_X(X_test)
In [51]:
X_train.shape
Out[51]:
(9867, 227, 227, 3)

Convert to one-hot

In [53]:
def convert_Y(ys, macro_classes):
    '''
    Convert labels to a one-hot np array.
    The macro-class ids are already sequential from zero,
    so we can use them directly.
    '''
    n_classes = len(macro_classes)
    return np_utils.to_categorical(ys, n_classes)

Y_train = convert_Y(Y_train, macro_classes)
Y_valid = convert_Y(Y_valid, macro_classes)
Y_test = convert_Y(Y_test, macro_classes)
In [54]:
Y_train.shape
Out[54]:
(9867, 8)
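
What `np_utils.to_categorical` does is easy to replicate in plain NumPy, which makes a handy sanity check:

```python
import numpy as np

def to_one_hot(ys, n_classes):
    """Plain-NumPy equivalent of np_utils.to_categorical:
    row i gets a 1 in column ys[i], zeros elsewhere."""
    ys = np.asarray(ys, dtype=int)
    out = np.zeros((len(ys), n_classes), dtype='float32')
    out[np.arange(len(ys)), ys] = 1.0
    return out

print(to_one_hot([5, 0, 4], 8).argmax(axis=1))  # -> [5 0 4]
```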
In [55]:
# normalize the data, this time leaving it in color
X_train_norm = normalize_for_cnn(X_train)
X_valid_norm = normalize_for_cnn(X_valid)
X_test_norm = normalize_for_cnn(X_test)
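
`normalize_for_cnn` was also defined earlier; a typical choice for a helper like it -- and purely a guess at this one -- is to scale the pixel values into [0, 1] and center them:

```python
import numpy as np

def normalize_for_cnn(X):
    """A common image normalization (a guess at the earlier helper):
    scale uint8-range pixel values into [0, 1], then subtract the
    mean so the inputs are roughly centered around zero."""
    X = X.astype('float32') / 255.0
    return X - X.mean()

X = np.random.randint(0, 256, size=(4, 8, 8, 3))
X_norm = normalize_for_cnn(X)
assert abs(X_norm.mean()) < 1e-4  # centered
assert X_norm.max() <= 1.0        # bounded
```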

Build our model.

In [167]:
# Let's use more or less the same model to start (only the number of classes changes)
def cnn_model2(use_dropout=True):
    model = Sequential()
    nb_filters = 16
    pool_size = (2,2)
    filter_size = 3
    nb_classes = len(macro_classes)
    
    with tf.name_scope("conv1") as scope:
        model.add(Convolution2D(nb_filters, filter_size, 
                            input_shape=(227, 227, 3)))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=pool_size))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("conv2") as scope:
        model.add(Convolution2D(nb_filters, filter_size))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=pool_size))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("conv3") as scope:
        model.add(Convolution2D(nb_filters, filter_size))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=pool_size))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("dense1") as scope:
        model.add(Flatten())
        model.add(Dense(16))
        model.add(Activation('relu'))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("softmax") as scope:
        model.add(Dense(nb_classes))
        model.add(Activation('softmax'))
    return model

# Uncomment if you get an "Invalid argument: You must feed a value
# for placeholder tensor ..." error when rerunning training.
# K.clear_session() # https://github.com/fchollet/keras/issues/4499
    

model3 = cnn_model2()
model3.compile(loss='categorical_crossentropy',
              optimizer='adadelta',
              metrics=['accuracy'])
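
A quick sanity check on the shapes this stack produces: each block is a "valid" 3x3 convolution (which trims 2 pixels off each side) followed by a 2x2 max-pool (which halves the size, rounding down). Starting from 227x227, the arithmetic (assuming those defaults) is:

```python
def conv_pool_out(size, filter_size=3, pool=2):
    """Spatial size after a 'valid' conv followed by a max-pool."""
    return (size - filter_size + 1) // pool

size = 227
for _ in range(3):       # conv1 -> conv2 -> conv3
    size = conv_pool_out(size)

print(size)              # -> 26
print(size * size * 16)  # flattened length seen by Dense(16): -> 10816
```

So the Flatten layer hands ~11K values to a 16-unit dense layer, which is where most of the model's parameters live.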
In [168]:
# This model will train slowly, so let's checkpoint it periodically
from keras.callbacks import ModelCheckpoint

Train the model...

In [169]:
recompute = False

# Define the checkpoint callback outside the if-block so the
# "train some more" cell further down can reuse it even when
# we load saved weights instead of retraining.
checkpoint = ModelCheckpoint('macro_class_cnn_checkpoint.5',
                             monitor='val_acc',
                             verbose=1,
                             save_best_only=True, mode='max',
                             save_weights_only=True)

if recompute:
#     # Save info during computation so we can see what's happening
#     tbCallback = TensorBoard(
#         log_dir='./graph', histogram_freq=1,
#         write_graph=False, write_images=False)

    # Fit the model! Using a bigger batch size and fewer epochs
    # because we have ~10K training images now instead of 100.
    history = model3.fit(
        X_train_norm, Y_train,
        batch_size=64, nb_epoch=50, verbose=1,
        validation_data=(X_valid_norm, Y_valid),
        callbacks=[checkpoint]
    )
else:
    model3.load_weights('macro_class_cnn.h5')
Train on 9867 samples, validate on 3159 samples
Epoch 1/50
9856/9867 [============================>.] - ETA: 0s - loss: 2.0057 - acc: 0.2244Epoch 00000: val_acc improved from -inf to 0.25578, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 471s - loss: 2.0055 - acc: 0.2247 - val_loss: 2.0049 - val_acc: 0.2558
Epoch 2/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.9588 - acc: 0.2312Epoch 00001: val_acc did not improve
9867/9867 [==============================] - 393s - loss: 1.9587 - acc: 0.2312 - val_loss: 1.9725 - val_acc: 0.1912
Epoch 3/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.9288 - acc: 0.2420Epoch 00002: val_acc improved from 0.25578 to 0.29946, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 377s - loss: 1.9287 - acc: 0.2419 - val_loss: 1.9587 - val_acc: 0.2995
Epoch 4/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.9061 - acc: 0.2589Epoch 00003: val_acc did not improve
9867/9867 [==============================] - 363s - loss: 1.9059 - acc: 0.2587 - val_loss: 1.8863 - val_acc: 0.2881
Epoch 5/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8896 - acc: 0.2719Epoch 00004: val_acc did not improve
9867/9867 [==============================] - 356s - loss: 1.8895 - acc: 0.2719 - val_loss: 1.8866 - val_acc: 0.2919
Epoch 6/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8780 - acc: 0.2731Epoch 00005: val_acc did not improve
9867/9867 [==============================] - 354s - loss: 1.8785 - acc: 0.2729 - val_loss: 1.9138 - val_acc: 0.2991
Epoch 7/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8672 - acc: 0.2819Epoch 00006: val_acc did not improve
9867/9867 [==============================] - 360s - loss: 1.8673 - acc: 0.2817 - val_loss: 1.8759 - val_acc: 0.2906
Epoch 8/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8498 - acc: 0.2870Epoch 00007: val_acc did not improve
9867/9867 [==============================] - 355s - loss: 1.8497 - acc: 0.2871 - val_loss: 1.8528 - val_acc: 0.2947
Epoch 9/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8514 - acc: 0.2878Epoch 00008: val_acc did not improve
9867/9867 [==============================] - 357s - loss: 1.8517 - acc: 0.2878 - val_loss: 1.8778 - val_acc: 0.2868
Epoch 10/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8385 - acc: 0.2894Epoch 00009: val_acc improved from 0.29946 to 0.30263, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 355s - loss: 1.8387 - acc: 0.2892 - val_loss: 1.8588 - val_acc: 0.3026
Epoch 11/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8294 - acc: 0.2960Epoch 00010: val_acc improved from 0.30263 to 0.30263, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 353s - loss: 1.8293 - acc: 0.2958 - val_loss: 1.8331 - val_acc: 0.3026
Epoch 12/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8220 - acc: 0.2987Epoch 00011: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.8220 - acc: 0.2987 - val_loss: 1.8173 - val_acc: 0.2887
Epoch 13/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8228 - acc: 0.2907Epoch 00012: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.8226 - acc: 0.2906 - val_loss: 1.8132 - val_acc: 0.2969
Epoch 14/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8164 - acc: 0.2928Epoch 00013: val_acc improved from 0.30263 to 0.30358, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 353s - loss: 1.8166 - acc: 0.2928 - val_loss: 1.8181 - val_acc: 0.3036
Epoch 15/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8111 - acc: 0.2984Epoch 00014: val_acc did not improve
9867/9867 [==============================] - 355s - loss: 1.8112 - acc: 0.2984 - val_loss: 1.8189 - val_acc: 0.2963
Epoch 16/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8065 - acc: 0.2958Epoch 00015: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.8066 - acc: 0.2959 - val_loss: 1.7982 - val_acc: 0.3010
Epoch 17/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8023 - acc: 0.2998Epoch 00016: val_acc improved from 0.30358 to 0.30611, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 354s - loss: 1.8021 - acc: 0.3000 - val_loss: 1.7957 - val_acc: 0.3061
Epoch 18/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.8020 - acc: 0.3019Epoch 00017: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.8021 - acc: 0.3019 - val_loss: 1.8046 - val_acc: 0.3001
Epoch 19/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7954 - acc: 0.3064Epoch 00018: val_acc improved from 0.30611 to 0.31244, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 352s - loss: 1.7955 - acc: 0.3064 - val_loss: 1.8051 - val_acc: 0.3124
Epoch 20/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7918 - acc: 0.3104Epoch 00019: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.7919 - acc: 0.3102 - val_loss: 1.7904 - val_acc: 0.3033
Epoch 21/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7891 - acc: 0.3123Epoch 00020: val_acc improved from 0.31244 to 0.31624, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 352s - loss: 1.7892 - acc: 0.3122 - val_loss: 1.7877 - val_acc: 0.3162
Epoch 22/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7813 - acc: 0.3081Epoch 00021: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.7814 - acc: 0.3080 - val_loss: 1.7718 - val_acc: 0.3134
Epoch 23/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7799 - acc: 0.3072Epoch 00022: val_acc improved from 0.31624 to 0.32605, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 353s - loss: 1.7802 - acc: 0.3072 - val_loss: 1.7871 - val_acc: 0.3261
Epoch 24/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7738 - acc: 0.3132Epoch 00023: val_acc did not improve
9867/9867 [==============================] - 358s - loss: 1.7737 - acc: 0.3133 - val_loss: 1.7769 - val_acc: 0.3159
Epoch 25/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7690 - acc: 0.3162Epoch 00024: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.7690 - acc: 0.3160 - val_loss: 1.7700 - val_acc: 0.3191
Epoch 26/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7686 - acc: 0.3123Epoch 00025: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.7685 - acc: 0.3123 - val_loss: 1.7652 - val_acc: 0.3102
Epoch 27/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7620 - acc: 0.3150Epoch 00026: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.7619 - acc: 0.3152 - val_loss: 1.7581 - val_acc: 0.3029
Epoch 28/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7589 - acc: 0.3155Epoch 00027: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.7589 - acc: 0.3155 - val_loss: 1.7608 - val_acc: 0.3166
Epoch 29/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7576 - acc: 0.3199Epoch 00028: val_acc improved from 0.32605 to 0.33428, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 353s - loss: 1.7580 - acc: 0.3196 - val_loss: 1.7608 - val_acc: 0.3343
Epoch 30/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7556 - acc: 0.3171Epoch 00029: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.7561 - acc: 0.3169 - val_loss: 1.7523 - val_acc: 0.3276
Epoch 31/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7425 - acc: 0.3253Epoch 00030: val_acc improved from 0.33428 to 0.33555, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 352s - loss: 1.7424 - acc: 0.3254 - val_loss: 1.7470 - val_acc: 0.3355
Epoch 32/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7421 - acc: 0.3319Epoch 00031: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.7419 - acc: 0.3318 - val_loss: 1.7477 - val_acc: 0.3159
Epoch 33/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7389 - acc: 0.3254Epoch 00032: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.7386 - acc: 0.3254 - val_loss: 1.7568 - val_acc: 0.3147
Epoch 34/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7385 - acc: 0.3332Epoch 00033: val_acc did not improve
9867/9867 [==============================] - 354s - loss: 1.7384 - acc: 0.3330 - val_loss: 1.7594 - val_acc: 0.3197
Epoch 35/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7297 - acc: 0.3287Epoch 00034: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.7295 - acc: 0.3290 - val_loss: 1.7310 - val_acc: 0.3343
Epoch 36/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7263 - acc: 0.3331Epoch 00035: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.7263 - acc: 0.3332 - val_loss: 1.7426 - val_acc: 0.3314
Epoch 37/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7257 - acc: 0.3361Epoch 00036: val_acc improved from 0.33555 to 0.34156, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 353s - loss: 1.7258 - acc: 0.3360 - val_loss: 1.7333 - val_acc: 0.3416
Epoch 38/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7255 - acc: 0.3344Epoch 00037: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.7256 - acc: 0.3343 - val_loss: 1.7354 - val_acc: 0.3346
Epoch 39/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7188 - acc: 0.3360Epoch 00038: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.7191 - acc: 0.3358 - val_loss: 1.7265 - val_acc: 0.3416
Epoch 40/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7179 - acc: 0.3429Epoch 00039: val_acc improved from 0.34156 to 0.34726, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 352s - loss: 1.7178 - acc: 0.3431 - val_loss: 1.7300 - val_acc: 0.3473
Epoch 41/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7094 - acc: 0.3414Epoch 00040: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.7094 - acc: 0.3413 - val_loss: 1.7362 - val_acc: 0.3242
Epoch 42/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7134 - acc: 0.3346Epoch 00041: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.7137 - acc: 0.3344 - val_loss: 1.7259 - val_acc: 0.3340
Epoch 43/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7025 - acc: 0.3442Epoch 00042: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.7024 - acc: 0.3442 - val_loss: 1.7304 - val_acc: 0.3469
Epoch 44/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7032 - acc: 0.3465Epoch 00043: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.7028 - acc: 0.3467 - val_loss: 1.7166 - val_acc: 0.3419
Epoch 45/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.7004 - acc: 0.3409Epoch 00044: val_acc improved from 0.34726 to 0.35391, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 353s - loss: 1.7002 - acc: 0.3409 - val_loss: 1.7167 - val_acc: 0.3539
Epoch 46/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6930 - acc: 0.3423Epoch 00045: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.6933 - acc: 0.3420 - val_loss: 1.7362 - val_acc: 0.3416
Epoch 47/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6945 - acc: 0.3431Epoch 00046: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.6946 - acc: 0.3429 - val_loss: 1.7109 - val_acc: 0.3501
Epoch 48/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6924 - acc: 0.3454Epoch 00047: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6922 - acc: 0.3456 - val_loss: 1.7114 - val_acc: 0.3435
Epoch 49/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6879 - acc: 0.3517Epoch 00048: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.6879 - acc: 0.3516 - val_loss: 1.7233 - val_acc: 0.3387
Epoch 50/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6841 - acc: 0.3543Epoch 00049: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.6844 - acc: 0.3541 - val_loss: 1.7130 - val_acc: 0.3460
In [ ]:
model3.save('macro_class_cnn.h5')
In [170]:
plot_training_curves(history.history);
In [173]:
# let's train some more -- clearly still getting better
history2 = model3.fit(
    X_train_norm, Y_train,
    batch_size=64, nb_epoch=50, verbose=1,
    validation_data=(X_valid_norm, Y_valid),
    callbacks=[checkpoint])
Train on 9867 samples, validate on 3159 samples
Epoch 1/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6808 - acc: 0.3503Epoch 00000: val_acc did not improve
9867/9867 [==============================] - 367s - loss: 1.6812 - acc: 0.3503 - val_loss: 1.7217 - val_acc: 0.3406
Epoch 2/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6835 - acc: 0.3466Epoch 00001: val_acc did not improve
9867/9867 [==============================] - 356s - loss: 1.6832 - acc: 0.3465 - val_loss: 1.6938 - val_acc: 0.3539
Epoch 3/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6842 - acc: 0.3464Epoch 00002: val_acc improved from 0.35391 to 0.35613, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 355s - loss: 1.6843 - acc: 0.3466 - val_loss: 1.7102 - val_acc: 0.3561
Epoch 4/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6782 - acc: 0.3620Epoch 00003: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6783 - acc: 0.3619 - val_loss: 1.7195 - val_acc: 0.3412
Epoch 5/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6803 - acc: 0.3578Epoch 00004: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.6803 - acc: 0.3576 - val_loss: 1.7068 - val_acc: 0.3536
Epoch 6/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6731 - acc: 0.3580Epoch 00005: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.6729 - acc: 0.3580 - val_loss: 1.7430 - val_acc: 0.3204
Epoch 7/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6700 - acc: 0.3607Epoch 00006: val_acc did not improve
9867/9867 [==============================] - 354s - loss: 1.6703 - acc: 0.3606 - val_loss: 1.6957 - val_acc: 0.3558
Epoch 8/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6665 - acc: 0.3607Epoch 00007: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.6670 - acc: 0.3603 - val_loss: 1.7140 - val_acc: 0.3552
Epoch 9/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6665 - acc: 0.3604Epoch 00008: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6660 - acc: 0.3605 - val_loss: 1.6907 - val_acc: 0.3498
Epoch 10/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6655 - acc: 0.3589Epoch 00009: val_acc improved from 0.35613 to 0.36024, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 353s - loss: 1.6656 - acc: 0.3588 - val_loss: 1.6868 - val_acc: 0.3602
Epoch 11/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6585 - acc: 0.3727Epoch 00010: val_acc did not improve
9867/9867 [==============================] - 354s - loss: 1.6585 - acc: 0.3726 - val_loss: 1.6995 - val_acc: 0.3482
Epoch 12/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6540 - acc: 0.3669Epoch 00011: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.6543 - acc: 0.3668 - val_loss: 1.6846 - val_acc: 0.3587
Epoch 13/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6503 - acc: 0.3753Epoch 00012: val_acc did not improve
9867/9867 [==============================] - 356s - loss: 1.6506 - acc: 0.3754 - val_loss: 1.6921 - val_acc: 0.3536
Epoch 14/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6508 - acc: 0.3727Epoch 00013: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6512 - acc: 0.3726 - val_loss: 1.6877 - val_acc: 0.3504
Epoch 15/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6574 - acc: 0.3643Epoch 00014: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.6571 - acc: 0.3644 - val_loss: 1.6949 - val_acc: 0.3511
Epoch 16/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6472 - acc: 0.3684Epoch 00015: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6467 - acc: 0.3685 - val_loss: 1.6681 - val_acc: 0.3596
Epoch 17/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6417 - acc: 0.3717Epoch 00016: val_acc improved from 0.36024 to 0.36784, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 352s - loss: 1.6411 - acc: 0.3720 - val_loss: 1.6707 - val_acc: 0.3678
Epoch 18/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6476 - acc: 0.3686Epoch 00017: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6479 - acc: 0.3683 - val_loss: 1.6923 - val_acc: 0.3580
Epoch 19/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6412 - acc: 0.3745Epoch 00018: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6412 - acc: 0.3744 - val_loss: 1.6769 - val_acc: 0.3637
Epoch 20/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6310 - acc: 0.3738Epoch 00019: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.6310 - acc: 0.3739 - val_loss: 1.6767 - val_acc: 0.3466
Epoch 21/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6422 - acc: 0.3700Epoch 00020: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.6421 - acc: 0.3700 - val_loss: 1.6829 - val_acc: 0.3593
Epoch 22/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6328 - acc: 0.3705Epoch 00021: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.6326 - acc: 0.3706 - val_loss: 1.6848 - val_acc: 0.3602
Epoch 23/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6412 - acc: 0.3758Epoch 00022: val_acc did not improve
9867/9867 [==============================] - 354s - loss: 1.6411 - acc: 0.3759 - val_loss: 1.6720 - val_acc: 0.3621
Epoch 24/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6392 - acc: 0.3713Epoch 00023: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6390 - acc: 0.3714 - val_loss: 1.6756 - val_acc: 0.3561
Epoch 25/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6287 - acc: 0.3753Epoch 00024: val_acc did not improve
9867/9867 [==============================] - 350s - loss: 1.6288 - acc: 0.3753 - val_loss: 1.6857 - val_acc: 0.3549
Epoch 26/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6263 - acc: 0.3761Epoch 00025: val_acc did not improve
9867/9867 [==============================] - 351s - loss: 1.6260 - acc: 0.3761 - val_loss: 1.6763 - val_acc: 0.3659
Epoch 27/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6161 - acc: 0.3792Epoch 00026: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.6165 - acc: 0.3791 - val_loss: 1.6848 - val_acc: 0.3596
Epoch 28/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6143 - acc: 0.3843Epoch 00027: val_acc did not improve
9867/9867 [==============================] - 355s - loss: 1.6142 - acc: 0.3844 - val_loss: 1.6665 - val_acc: 0.3659
Epoch 29/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6232 - acc: 0.3771Epoch 00028: val_acc improved from 0.36784 to 0.36942, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 351s - loss: 1.6232 - acc: 0.3773 - val_loss: 1.6715 - val_acc: 0.3694
Epoch 30/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6174 - acc: 0.3765Epoch 00029: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6176 - acc: 0.3764 - val_loss: 1.6767 - val_acc: 0.3577
Epoch 31/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6215 - acc: 0.3826Epoch 00030: val_acc improved from 0.36942 to 0.36974, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 351s - loss: 1.6213 - acc: 0.3827 - val_loss: 1.6652 - val_acc: 0.3697
Epoch 32/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6242 - acc: 0.3780Epoch 00031: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6243 - acc: 0.3779 - val_loss: 1.6808 - val_acc: 0.3561
Epoch 33/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6132 - acc: 0.3855Epoch 00032: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.6136 - acc: 0.3852 - val_loss: 1.6900 - val_acc: 0.3549
Epoch 34/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6083 - acc: 0.3791Epoch 00033: val_acc did not improve
9867/9867 [==============================] - 352s - loss: 1.6083 - acc: 0.3789 - val_loss: 1.6688 - val_acc: 0.3514
Epoch 35/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6165 - acc: 0.3809Epoch 00034: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.6161 - acc: 0.3811 - val_loss: 1.6572 - val_acc: 0.3647
Epoch 36/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6150 - acc: 0.3849Epoch 00035: val_acc did not improve
9867/9867 [==============================] - 353s - loss: 1.6146 - acc: 0.3852 - val_loss: 1.6454 - val_acc: 0.3678
Epoch 37/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6202 - acc: 0.3766Epoch 00036: val_acc did not improve
9867/9867 [==============================] - 375s - loss: 1.6200 - acc: 0.3766 - val_loss: 1.6733 - val_acc: 0.3602
Epoch 38/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.5970 - acc: 0.3876Epoch 00037: val_acc did not improve
9867/9867 [==============================] - 382s - loss: 1.5970 - acc: 0.3877 - val_loss: 1.6659 - val_acc: 0.3587
Epoch 39/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.5945 - acc: 0.3885Epoch 00038: val_acc did not improve
9867/9867 [==============================] - 374s - loss: 1.5945 - acc: 0.3887 - val_loss: 1.6936 - val_acc: 0.3618
Epoch 40/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.5971 - acc: 0.3845Epoch 00039: val_acc did not improve
9867/9867 [==============================] - 378s - loss: 1.5975 - acc: 0.3845 - val_loss: 1.6654 - val_acc: 0.3663
Epoch 41/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6151 - acc: 0.3836Epoch 00040: val_acc improved from 0.36974 to 0.37417, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 369s - loss: 1.6151 - acc: 0.3833 - val_loss: 1.6565 - val_acc: 0.3742
Epoch 42/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6002 - acc: 0.3881Epoch 00041: val_acc did not improve
9867/9867 [==============================] - 358s - loss: 1.6004 - acc: 0.3879 - val_loss: 1.6648 - val_acc: 0.3644
Epoch 43/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6051 - acc: 0.3831Epoch 00042: val_acc did not improve
9867/9867 [==============================] - 362s - loss: 1.6051 - acc: 0.3833 - val_loss: 1.6772 - val_acc: 0.3587
Epoch 44/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.6054 - acc: 0.3800Epoch 00043: val_acc did not improve
9867/9867 [==============================] - 358s - loss: 1.6055 - acc: 0.3801 - val_loss: 1.6545 - val_acc: 0.3716
Epoch 45/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.5992 - acc: 0.3909Epoch 00044: val_acc improved from 0.37417 to 0.37607, saving model to macro_class_cnn_checkpoint.5
9867/9867 [==============================] - 352s - loss: 1.5992 - acc: 0.3908 - val_loss: 1.6501 - val_acc: 0.3761
Epoch 46/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.5940 - acc: 0.3858Epoch 00045: val_acc did not improve
9867/9867 [==============================] - 362s - loss: 1.5937 - acc: 0.3858 - val_loss: 1.6658 - val_acc: 0.3561
Epoch 47/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.5882 - acc: 0.3916Epoch 00046: val_acc did not improve
9867/9867 [==============================] - 371s - loss: 1.5884 - acc: 0.3916 - val_loss: 1.6982 - val_acc: 0.3302
Epoch 48/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.5881 - acc: 0.3895Epoch 00047: val_acc did not improve
9867/9867 [==============================] - 364s - loss: 1.5877 - acc: 0.3898 - val_loss: 1.6527 - val_acc: 0.3701
Epoch 49/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.5958 - acc: 0.3865Epoch 00048: val_acc did not improve
9867/9867 [==============================] - 358s - loss: 1.5954 - acc: 0.3866 - val_loss: 1.6918 - val_acc: 0.3545
Epoch 50/50
9856/9867 [============================>.] - ETA: 0s - loss: 1.5955 - acc: 0.3884Epoch 00049: val_acc did not improve
9867/9867 [==============================] - 358s - loss: 1.5952 - acc: 0.3885 - val_loss: 1.6797 - val_acc: 0.3634
In [180]:
from helpers import combine_histories
plot_training_curves(combine_histories(history.history, history2.history));
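The `combine_histories` helper comes from a local `helpers` module that isn't shown in this notebook. A minimal sketch of what it might look like, assuming each argument is a Keras `History.history`-style dict mapping metric names to per-epoch lists (the function name matches the import, but this body is my guess, not the actual helper):

```python
def combine_histories(h1, h2):
    """Concatenate two Keras history dicts (metric name -> list of
    per-epoch values) from back-to-back training runs, assuming both
    runs tracked the same metrics."""
    combined = {}
    for key in h1:
        combined[key] = list(h1[key]) + list(h2.get(key, []))
    return combined
```

This lets the training-curve plot treat the two `fit()` calls as one continuous run.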

Diagnosing...

The model is starting to overfit. Let's try to diagnose what's going on, then decide what to do.

Now that we have 8 different classes, we can see how often they get confused for each other by looking at the aptly named confusion matrix. I would expect lots of confusion between coupes and sedans, between vans and minivans, and between SUVs and wagons.
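As a reminder of how to read one: in scikit-learn's convention, rows are true classes and columns are predicted classes, so off-diagonal entries count specific kinds of confusion. A tiny standalone example:

```python
from sklearn.metrics import confusion_matrix

# Three toy classes; rows = true class, columns = predicted class.
y_true = [0, 0, 1, 1, 2]
y_pred = [0, 1, 1, 1, 0]
print(confusion_matrix(y_true, y_pred))
# [[1 1 0]
#  [0 2 0]
#  [1 0 0]]
```

Here entry (0, 1) = 1 means one true-class-0 example was predicted as class 1.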

In [190]:
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(labels, predictions,
                          classes,
                          normalize=False,
                          title="Confusion matrix",
                          cmap=plt.cm.Blues):
    """
    Plot a confusion matrix for predictions vs labels. 
    Both should be one-hot.
    
    Based on:
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    """
    # convert from one-hot to class indices
    cat_labels = np.argmax(labels, axis=1)
    cat_predicts = np.argmax(predictions, axis=1)
    
    cm = confusion_matrix(cat_labels, cat_predicts)
    
    # normalize before plotting, so the colors and cell annotations
    # match the printed values
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()    
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    thresh = cm.max() / 2.
    for i, j in it.product(range(cm.shape[0]), range(cm.shape[1])):
        # show two decimals for normalized entries, raw counts otherwise
        if 0 < cm[i, j] < 1:
            val = "{:.2f}".format(cm[i, j])
        else:
            val = cm[i, j]
        plt.text(j, i, val,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
In [182]:
# Get the predictions
predict_train = model3.predict(X_train_norm)
predict_valid = model3.predict(X_valid_norm)
predict_test = model3.predict(X_test_norm)
In [192]:
plot_confusion_matrix(Y_test, predict_test, macro_classes,
                      normalize=False,
                      title="Test confusion matrix");
Confusion matrix, without normalization
[[ 12  49   0   5  49 286   0   0]
 [ 10  98   0   2  46 390   0   0]
 [  0   0   0   0  20  77   0   0]
 [  1   4   0  31 162  99   0   0]
 [  4  10   0  13 290 270   0   0]
 [  2  35   0   0  92 675   0   0]
 [  1   5   0   2  43  62   0   0]
 [  1  22   0   4  43 244   0   0]]
In [193]:
plot_confusion_matrix(Y_train, predict_train, macro_classes,
                      title="Train confusion matrix")
Confusion matrix, without normalization
[[  78  156    0    3  105  914    0    0]
 [  17  593    0    3   82 1014    0    0]
 [   0    3    0    0   65  236    0    0]
 [   0    2    0  241  406  271    0    0]
 [   3   11    0   14 1182  624    0    0]
 [   0   21    0    2  179 2312    0    0]
 [   0    2    0    5  125  220    0    0]
 [   0   33    0    3  124  818    0    0]]

Well, it seems that most car types get classified as sedan. Not too surprising, especially given that sedans are overrepresented. The model is starting to learn that SUVs and pickups are different from sedans, and occasionally manages to distinguish coupes from sedans.

So far, it never predicts the minivan, van, or wagon labels at all.

Things to check / do:

  • Is it classifying coupes correctly from the side but incorrectly from the front or back?
  • Look at the full class-probability distribution for some images, not just the maximal one
  • Count how many training images we have for each class. We may want to oversample the underrepresented classes.
  • Try fine-tuning an off-the-shelf model.
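On the oversampling idea, here's a hedged sketch (the function name and structure are mine, not from this notebook) that resamples training indices with replacement so every class appears as often as the largest one:

```python
import numpy as np

def oversample_indices(Y_onehot, rng=np.random):
    """Return an index array into the training set in which each class
    is resampled (with replacement) up to the size of the largest class.
    Y_onehot: (n_samples, n_classes) one-hot label matrix."""
    labels = np.argmax(Y_onehot, axis=1)
    counts = np.bincount(labels, minlength=Y_onehot.shape[1])
    target = counts.max()
    idx = []
    for c in range(Y_onehot.shape[1]):
        class_idx = np.where(labels == c)[0]
        if len(class_idx) == 0:
            continue  # class absent from this split
        idx.append(rng.choice(class_idx, size=target, replace=True))
    return np.concatenate(idx)
```

You would then shuffle the returned indices and train on `X_train_norm[idx], Y_train[idx]`; the trade-off is more epochs spent re-seeing duplicated minority-class images.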

To be continued

We'll pick up in 10b-cars-continued.ipynb...
